├── gsmvi
│   ├── __init__.py
│   ├── initializers.py
│   ├── monitors.py
│   ├── advi.py
│   ├── gsm_numpy.py
│   ├── gsm.py
│   └── bam.py
├── pyproject.toml
├── LICENSE
├── examples
│   ├── example_gsm.py
│   ├── example_gsm_numpy.py
│   ├── example_advi.py
│   ├── example_bam.py
│   └── example_initializers.py
├── .gitignore
└── README.md
/gsmvi/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gsmvi/initializers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import minimize
3 |
4 |
5 | def lbfgs_init(x0, lp, lp_g=None, maxiter=1000, maxfun=1000):
6 |     """
7 |     Initialize the mean and covariance for VI with an L-BFGS-B fit.
8 |
9 |     Inputs:
10 |       x0 : Array of size D, initial guess for the mean
11 |       lp : Function to evaluate the target log-probability
12 |       lp_g : Optional, function to evaluate the score, i.e. the gradient of lp
13 |       maxiter, maxfun : Optional, maximum number of iterations / function evaluations
14 |
15 |     Returns:
16 |       mu : Array of size D, the (approximate) MAP estimate, to initialize the mean
17 |       cov : Array of size DxD, the inverse-Hessian estimate, to initialize the covariance
18 |       res : scipy OptimizeResult of the minimization
19 |     """
20 |     f = lambda x: -lp(x)
21 |     if lp_g is not None:
22 |         f_g = lambda x: -lp_g(x)
23 |     else:
24 |         f_g = None
25 |     res = minimize(f, x0, method='L-BFGS-B', jac=f_g,
26 |                    options={"maxiter": maxiter, "maxfun": maxfun})
27 |
28 |     mu = res.x
29 |     cov = res.hess_inv.todense()
30 |     return mu, cov, res
31 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "gsmvi"
7 | version = "0.0.7"
8 | authors = [
9 | { name="Chirag Modi", email="modichirag92@gmail.com" },
10 | ]
11 | description = "Implementation of Gaussian score matching for variational inference (arXiv:2307.07849)"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 | dependencies = [
20 | ]
21 |
22 |
23 | [project.optional-dependencies]
24 | full = [
25 | "jax",
26 | "jaxlib",
27 | "numpyro",
28 | "optax",
29 | "scipy"]
30 |
31 | [project.urls]
32 | Homepage = "https://github.com/modichirag/GSM-VI/tree/package"
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chirag Modi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/example_gsm.py:
--------------------------------------------------------------------------------
1 | ## A basic example for fitting a target Multivariate Gaussian distribution with GSM updates
2 |
3 | ## Uncomment the following lines if you run into memory issues with JAX
4 | # import os
5 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
7 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
8 |
9 | import numpy as np
10 | import jax.numpy as jnp
11 | from jax import jit, grad, random
12 | import numpyro.distributions as dist
13 |
14 | from gsmvi.gsm import GSM
15 |
16 | #####
17 | def setup_model(D=10):
18 |
19 | # setup a Gaussian target distribution
20 | mean = np.random.random(D)
21 | L = np.random.normal(size = D**2).reshape(D, D)
22 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3
23 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov)
24 | return model
25 |
26 |
27 | if __name__=="__main__":
28 |
29 | ###
30 |     # setup a toy Gaussian model and extract the score needed for GSM
31 | D = 10
32 | model = setup_model(D=D)
33 | mean, cov = model.loc, model.covariance_matrix
34 | lp = jit(lambda x: jnp.sum(model.log_prob(x)))
35 | lp_g = jit(grad(lp, argnums=0))
36 |
37 | ###
38 | # Fit with GSM
39 | niter = 500
40 | key = random.PRNGKey(99)
41 | gsm = GSM(D=D, lp=lp, lp_g=lp_g)
42 | mean_fit, cov_fit = gsm.fit(key, niter=niter)
43 |
44 | print("\nTrue mean : ", mean)
45 | print("Fit mean : ", mean_fit)
46 |
--------------------------------------------------------------------------------
/examples/example_gsm_numpy.py:
--------------------------------------------------------------------------------
1 | ## A basic example for fitting a target Multivariate Gaussian distribution with GSM updates
2 |
3 | import numpy as np
4 |
5 | from gsmvi.gsm_numpy import GSM
6 |
7 | #####
8 | def setup_model(D):
9 |
10 | # setup a Gaussian target distribution
11 | mean = np.random.random(D)
12 | L = np.random.normal(size = D**2).reshape(D, D)
13 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3
14 | icov = np.linalg.inv(cov)
15 |
16 | # functions for log_prob and score. These are to be supplied by the user
17 | def lp(x):
18 | assert len(x.shape) == 2
19 | lp = 0
20 | for i in range(x.shape[0]):
21 | lp += -0.5 * np.dot(np.dot(mean - x[i], icov), mean - x[i])
22 | return lp
23 |
24 | def lp_g(x):
25 | assert len(x.shape) == 2
26 | lp_g = []
27 | for i in range(x.shape[0]):
28 | lp_g.append( -1. * np.dot(icov, x[i] - mean))
29 | return np.array(lp_g)
30 |
31 | return mean, cov, lp, lp_g
32 |
33 |
34 | if __name__=="__main__":
35 |
36 | ###
37 |     # setup a toy Gaussian model and extract the score needed for GSM
38 | D = 5
39 | mean, cov, lp, lp_g = setup_model(D=D)
40 |
41 | ###
42 | # Fit with GSM
43 | niter = 500
44 | key = 99
45 | gsm = GSM(D=D, lp=lp, lp_g=lp_g)
46 | mean_fit, cov_fit = gsm.fit(key, niter=niter)
47 |
48 | print("\nTrue mean : ", mean)
49 | print("Fit mean : ", mean_fit)
50 |
--------------------------------------------------------------------------------
/examples/example_advi.py:
--------------------------------------------------------------------------------
1 | ## A basic example for fitting a target Multivariate Gaussian distribution with ADVI
2 |
3 | ## Uncomment the following lines if you run into memory issues with JAX
4 | # import os
5 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
7 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
8 |
9 | import numpy as np
10 | import jax.numpy as jnp
11 | from jax import jit, grad, random
12 | import optax
13 | import numpyro.distributions as dist
14 |
15 | from gsmvi.advi import ADVI
16 |
17 | def setup_model(D=10):
18 |
19 | # setup a Gaussian target distribution
20 | mean = np.random.random(D)
21 | L = np.random.normal(size = D**2).reshape(D, D)
22 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3
23 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov)
24 | return model
25 |
26 |
27 | if __name__=="__main__":
28 |
29 | ###
30 |     # setup a toy Gaussian model and extract the log-prob
31 | D = 4
32 | model = setup_model(D=D)
33 | mean, cov = model.loc, model.covariance_matrix
34 | lp = jit(lambda x: jnp.sum(model.log_prob(x)))
35 |
36 | ###
37 | # Fit with advi
38 | niter = 10000
39 | lr = 1e-2
40 | batch_size = 16
41 | advi = ADVI(D=D, lp=lp)
42 | key = random.PRNGKey(99)
43 | opt = optax.adam(learning_rate=lr)
44 | mean_fit, cov_fit, losses = advi.fit(key, opt, batch_size=batch_size, niter=niter)
45 |
46 | print("\nTrue mean : ", mean)
47 | print("Fit mean : ", mean_fit)
48 |
49 |
--------------------------------------------------------------------------------
/examples/example_bam.py:
--------------------------------------------------------------------------------
1 | ## Most basic example for fitting a target Multivariate Gaussian distribution with BaM updates
2 |
3 | import numpy as np
4 | import os
5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
7 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
8 | #os.environ['CUDA_VISIBLE_DEVICES'] = '' # To enable CPU backend
9 |
10 | from jax.lib import xla_bridge
11 | print("Device : ", xla_bridge.get_backend().platform)
12 |
13 | # enable 64 bit precision for jax
14 | from jax import config
15 | config.update("jax_enable_x64", True)
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | from jax import jit, grad, random
20 |
21 | import numpyro
22 | import numpyro.distributions as dist
23 |
24 | # Import BaM
25 | from gsmvi.bam import BaM, Regularizers
26 | from gsmvi.monitors import KLMonitor
27 | #####
28 |
29 |
30 | #####
31 | def setup_model(D=10):
32 |
33 | # setup a Gaussian target distribution
34 | mean = np.random.random(D)
35 | L = np.random.normal(size = D**2).reshape(D, D)
36 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3
37 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov)
38 | lp = jit(lambda x: jnp.sum(model.log_prob(x)))
39 | lp_g = jit(grad(lp, argnums=0))
40 |
41 | return model, mean, cov, lp, lp_g
42 |
43 |
44 |
45 | if __name__=="__main__":
46 |
47 | D = 5
48 | model, mean, cov, lp, lp_g = setup_model(D=D)
49 | ref_samples = model.sample(random.PRNGKey(99), (1000,))
50 |
51 | niter = 100
52 | batch_size = 2
53 | regularizer = Regularizers()
54 |
55 | # Example regularization functions
56 | # regf = regularizer.constant(100)
57 | # regf = regularizer.linear(100)
58 | func = lambda i : 100/(1+i)
59 | regf = regularizer.custom(func)
60 |
61 |
62 | bam = BaM(D=D, lp=lp, lp_g=lp_g, use_lowrank=True, jit_compile=True)
63 | key = random.PRNGKey(99)
64 | mean_fit, cov_fit = bam.fit(key, regf=regf, niter=niter, batch_size=batch_size)
65 |
66 | print()
67 | print("True mean : ", mean)
68 | print("Fit mean : ", mean_fit)
69 | print()
70 | print("Check mean fit")
71 | print(np.allclose(mean, mean_fit))
72 |
73 | print()
74 | print("Check cov fit")
75 | print(np.allclose(cov, cov_fit))
76 | print(cov)
77 | print()
78 | print(cov_fit)
79 |
--------------------------------------------------------------------------------
/examples/example_initializers.py:
--------------------------------------------------------------------------------
1 | ## Example for fitting a target Multivariate Gaussian distribution with GSM and ADVI
2 | ## Variational distribution is initialized with LBFGS fit for the mean and covariance
3 | ## The progress is monitored with a Monitor class.
4 |
5 | ## Uncomment the following lines if you run into memory issues with JAX
6 | # import os
7 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
8 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
9 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
10 |
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | import jax.numpy as jnp
14 | from jax import jit, grad, random
15 | import numpyro.distributions as dist
16 | import optax
17 |
18 | # enable 64 bit precision for jax, required for the lbfgs initializer
19 | from jax import config
20 | config.update("jax_enable_x64", True)
21 |
22 | from gsmvi.gsm import GSM
23 | from gsmvi.advi import ADVI
24 | from gsmvi.initializers import lbfgs_init
25 | from gsmvi.monitors import KLMonitor
26 |
27 | #####
28 | def setup_model(D=10):
29 |
30 | # setup a Gaussian target distribution
31 | mean = np.random.random(D)
32 | L = np.random.normal(size = D**2).reshape(D, D)
33 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3
34 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov)
35 | return model
36 |
37 |
38 | #
39 | def gsm_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res):
40 | print("Now fit with GSM")
41 | niter = 500
42 | batch_size = 1
43 | key = random.PRNGKey(99)
44 | monitor = KLMonitor(batch_size_kl=32, checkpoint=10, \
45 | offset_evals=lbfgs_res.nfev) #note the offset number of evals
46 |
47 | gsm = GSM(D=D, lp=lp, lp_g=lp_g)
48 | mean_fit, cov_fit = gsm.fit(key, mean=mean_init, cov=cov_init, niter=niter, batch_size=batch_size, monitor=monitor)
49 | return mean_fit, cov_fit, monitor
50 |
51 |
52 | #
53 | def advi_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res):
54 | print("\nNow fit with ADVI")
55 | niter = 500
56 | lr = 1e-2
57 | batch_size = 1
58 | key = random.PRNGKey(99)
59 | opt = optax.adam(learning_rate=lr)
60 | monitor = KLMonitor(batch_size_kl=32, checkpoint=10, \
61 | offset_evals=lbfgs_res.nfev) #note the offset number of evals
62 |
63 | advi = ADVI(D=D, lp=lp)
64 | mean_fit, cov_fit, losses = advi.fit(key, mean=mean_init, cov=cov_init, opt=opt, batch_size=batch_size, niter=niter, monitor=monitor)
65 | return mean_fit, cov_fit, monitor
66 |
67 |
68 |
69 | if __name__=="__main__":
70 |
71 | ###
72 |     # setup a toy Gaussian model and extract the score needed for GSM
73 | D = 16
74 | model = setup_model(D=D)
75 | mean, cov = model.loc, model.covariance_matrix
76 | lp = jit(lambda x: jnp.sum(model.log_prob(x)))
77 | lp_g = jit(grad(lp, argnums=0))
78 |
79 | ###
80 | print("Initialize with LBFGS")
81 |     mean_init = np.ones(D) # initial guess for the LBFGS fit used to initialize GSM and ADVI
82 | mean_init, cov_init, lbfgs_res = lbfgs_init(mean_init, lp, lp_g)
83 | print(f'LBFGS fit: \n{lbfgs_res}\n')
84 |
85 | mean_gsm, cov_gsm, monitor_gsm = gsm_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res)
86 | mean_advi, cov_advi, monitor_advi = advi_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res)
87 |
88 | # Check that the output is correct
89 | print("\nTrue mean : ", mean)
90 | print("Fit gsm : ", mean_gsm)
91 | print("Fit advi : ", mean_advi)
92 |
93 |
94 | # Check that the KL divergence decreases
95 | plt.plot(monitor_gsm.nevals, monitor_gsm.rkl, label='GSM')
96 | plt.plot(monitor_advi.nevals, monitor_advi.rkl, label='ADVI')
97 | plt.legend()
98 |     plt.xlabel("Number of gradient evaluations")
99 | plt.ylabel("Reverse KL")
100 | plt.savefig("monitor_kl.png")
101 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/gsmvi/monitors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from functools import partial
4 |
5 | import jax.numpy as jnp
6 | from jax import jit, grad, random
7 | from numpyro.distributions import MultivariateNormal
8 |
9 |
10 | def reverse_kl(samples, lpq, lpp):
11 |     '''Monte Carlo estimate of the reverse KL, E_q[log q - log p], over samples from q'''
12 |     logl = np.sum(lpp(samples))
13 |     logq = np.sum(lpq(samples))
14 |     rkl = (logq - logl) / samples.shape[0]
15 |     return rkl
16 |
17 | def forward_kl(samples, lpq, lpp):
18 |     '''Monte Carlo estimate of the forward KL, E_p[log p - log q], over samples from p'''
19 |     logl = np.sum(lpp(samples))
20 |     logq = np.sum(lpq(samples))
21 |     fkl = (logl - logq) / samples.shape[0]
22 |     return fkl
23 |
24 | @partial(jit, static_argnums=(3))
25 | def reverse_kl_jit(samples, mu, cov, lp):
26 | q = MultivariateNormal(mu, cov)
27 | logq = jnp.sum(q.log_prob(samples))
28 | logl = jnp.sum(lp(samples))
29 | rkl = logq - logl
30 | rkl /= samples.shape[0]
31 | return rkl
32 |
33 | @partial(jit, static_argnums=(3))
34 | def forward_kl_jit(samples, mu, cov, lp):
35 | q = MultivariateNormal(mu, cov)
36 | logq = jnp.sum(q.log_prob(samples))
37 | logl = jnp.sum(lp(samples))
38 | fkl = logl- logq
39 | fkl /= samples.shape[0]
40 | return fkl
41 |
42 |
43 | @dataclass
44 | class KLMonitor():
45 | """
46 | Class to monitor KL divergence during optimization for VI
47 |
48 | Inputs:
49 |
50 | batch_size_kl: (int) Number of samples to use to estimate KL divergence
51 | checkpoint: (int) Number of iterations after which to run monitor
52 |     offset_evals: (int) Value with which to offset the number of gradient evaluations.
53 |                   Used to account for gradient evaluations done in warmup or initialization
54 | ref_samples: Optional, samples from the target distribution.
55 | If provided, also track forward KL divergence
56 | """
57 |
58 | batch_size_kl : int = 8
59 | checkpoint : int = 20
60 | offset_evals : int = 0
61 |     ref_samples : np.ndarray = None
62 |
63 | def __post_init__(self):
64 |
65 | self.rkl = []
66 | self.fkl = []
67 | self.nevals = []
68 |
69 | def reset(self,
70 | batch_size_kl=None,
71 | checkpoint=None,
72 | offset_evals=None,
73 | ref_samples=None ):
74 | self.nevals = []
75 | self.rkl = []
76 | self.fkl = []
77 | if batch_size_kl is not None: self.batch_size_kl = batch_size_kl
78 | if checkpoint is not None: self.checkpoint = checkpoint
79 | if offset_evals is not None: self.offset_evals = offset_evals
80 | if ref_samples is not None: self.ref_samples = ref_samples
81 | print('offset evals reset to : ', self.offset_evals)
82 |
83 | def __call__(self, i, params, lp, key, nevals=1):
84 | """
85 | Main function to monitor reverse (and forward) KL divergence over iterations.
86 |
87 | Inputs:
88 |
89 | i: (int) iteration number
90 | params: (tuple; (mean, cov)) Current estimate of mean and covariance matrix
91 | lp: Function to evaluate target log-probability
92 | key: Random number generator key (jax.random.PRNGKey)
93 | nevals: (int) Number of gradient evaluations SINCE the last call of the monitor function
94 |
95 | Returns:
96 |         key : New key for random number generation
97 | """
98 |
99 | #
100 | mu, cov = params
101 | key, key_sample = random.split(key)
102 | np.random.seed(key_sample[0])
103 |
104 |
105 | try:
106 | qsamples = np.random.multivariate_normal(mean=mu, cov=cov, size=self.batch_size_kl)
107 | q = MultivariateNormal(loc=mu, covariance_matrix=cov)
108 | self.rkl.append(reverse_kl(qsamples, q.log_prob, lp))
109 |
110 | if self.ref_samples is not None:
111 | idx = np.random.permutation(self.ref_samples.shape[0])[:self.batch_size_kl]
112 | psamples = self.ref_samples[idx]
113 | self.fkl.append(forward_kl(psamples, q.log_prob, lp))
114 | else:
115 |                 self.fkl.append(np.nan)
116 |
117 | except Exception as e:
118 |             print(f"Exception occurred in monitor : {e}.\nAppending NaN")
119 |             self.rkl.append(np.nan)
120 |             self.fkl.append(np.nan)
121 |
122 | self.nevals.append(self.offset_evals + nevals)
123 | self.offset_evals = self.nevals[-1]
124 |
125 | return key
126 |
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GSM-VI: Gaussian score-based variational inference algorithms
2 |
3 | This repository provides a Python implementation of two score-based (black-box)
4 | variational inference algorithms:
5 | 1. GSM-VI, which is described in the NeurIPS 2023 paper https://arxiv.org/abs/2307.07849.
6 | 2. Batch and match (BaM), which is described in the ICML 2024 paper https://arxiv.org/pdf/2402.14758.
7 |
8 | We describe each of these in the following two sections.
9 |
10 | ## Variational Inference (VI) with Gaussian Score Matching (GSM) (NeurIPS 2023)
11 |
12 | GSM-VI fits a multivariate Gaussian distribution with a dense covariance matrix to the target distribution
13 | by score matching. It only requires access to the score function, i.e. the gradient of the target log-probability
14 | distribution, and implements analytic updates for the variational parameters (mean and covariance matrix).
15 |
16 | ### Installation:
17 | The code is available on `PyPI`
18 | ```
19 | pip install gsmvi
20 | ```
21 |
22 | ### Usage
23 | The simplest version of the algorithm is written in numpy.
24 | The following is the minimal code to use GSM to fit the parameters `x` of a `model` given its `log_prob` and `log_prob_grad` functions.
25 | See `examples/example_gsm_numpy.py` for a full example.
26 | ```
27 | D = 10   # number of dimensions
28 | def log_prob(x):
29 |     # return the log-probability at sample x
30 |     ...
31 |
32 | def log_prob_grad(x):
33 |     # return the score function, i.e. the gradient of the log-probability at sample x
34 |     ...
35 |
36 | from gsmvi.gsm_numpy import GSM
37 | gsm = GSM(D=D, lp=log_prob, lp_g=log_prob_grad)
38 | random_seed = 99
39 | number_of_iterations = 500
40 | mean_fit, cov_fit = gsm.fit(key=random_seed, niter=number_of_iterations)
41 | ```
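
To make this concrete, here is a minimal runnable sketch with a standard normal as the toy target (our choice for illustration; `examples/example_gsm_numpy.py` uses a correlated Gaussian instead):
```
import numpy as np
from gsmvi.gsm_numpy import GSM

D = 10

def log_prob(x):        # x has shape (batch, D)
    return -0.5 * np.sum(x**2)

def log_prob_grad(x):   # score of a standard normal
    return -x

gsm = GSM(D=D, lp=log_prob, lp_g=log_prob_grad)
mean_fit, cov_fit = gsm.fit(key=99, niter=500)  # should recover mean ~ 0 and cov ~ identity
```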
42 |
43 | A more efficient version of the algorithm is implemented in Jax, where it benefits from jit compilation. The basic signature stays the same; note that `lp` sums the log-probability over the batch of samples, so that its gradient returns per-sample scores.
44 | See `examples/example_gsm.py` for a full example.
45 | ```
46 | D = 10   # number of dimensions
47 | model = setup_model(D=D) # this example sets up a numpyro model which implements a log_prob method
48 | lp = jit(lambda x: jnp.sum(model.log_prob(x)))
49 | lp_g = jit(grad(lp, argnums=0))
50 |
51 | from gsmvi.gsm import GSM
52 | gsm = GSM(D=D, lp=lp, lp_g=lp_g)
53 | mean_fit, cov_fit = gsm.fit(key=random.PRNGKey(99), niter=500)
54 | ```
55 |
56 | #### Other utilities:
57 | - For comparison, we also provide an implementation of the ADVI algorithm (https://arxiv.org/abs/1603.00788),
58 |   another common approach that fits a multivariate Gaussian variational distribution by maximizing the ELBO.
59 | - We provide an LBFGS initialization for the variational distribution which can be used with GSM and ADVI (see the sketch below).
60 | - We also provide a Monitor class to track the KL divergence over iterations as the algorithms progress.
61 |
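The following is a minimal sketch of how these utilities fit together, assuming the jax `lp` and `lp_g` functions from the GSM example above (see `examples/example_initializers.py` for the full version):
```
import numpy as np
from jax import random
from gsmvi.gsm import GSM
from gsmvi.initializers import lbfgs_init
from gsmvi.monitors import KLMonitor

# LBFGS fit of the mean and covariance to initialize the variational distribution
mean_init, cov_init, res = lbfgs_init(np.ones(D), lp, lp_g)

# track reverse KL every 10 iterations; offset by the evaluations LBFGS already used
monitor = KLMonitor(batch_size_kl=32, checkpoint=10, offset_evals=res.nfev)

gsm = GSM(D=D, lp=lp, lp_g=lp_g)
mean_fit, cov_fit = gsm.fit(random.PRNGKey(99), mean=mean_init, cov=cov_init,
                            niter=500, monitor=monitor)
# monitor.nevals and monitor.rkl now trace the reverse KL over gradient evaluations
```
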
62 | ### Code Dependencies
63 | The vanilla code is written in python3; its only dependency is `numpy`.
64 |
65 | #### Optional dependencies
66 | These will not be installed with the package and should be installed by the user depending on the use case.
67 |
68 | The Jax version of the code requires `jax` and `jaxlib`.
69 | The target distributions in example files other than `example_gsm_numpy.py` are implemented in `numpyro`.
70 | The ADVI algorithm uses `optax` for maximizing the ELBO.
71 | The LBFGS initialization of the variational distribution uses `scipy`.
72 |
73 | ### Starting point
74 | We provide simple examples in the `examples/` folder to fit a target multivariate Gaussian distribution with GSM and ADVI.
75 | ```
76 | cd examples
77 | python3 example_gsm_numpy.py # vanilla example in numpy, no dependencies
78 | python3 example_gsm.py # jax version, requires jax and numpyro
79 | python3 example_advi.py # jax version, requires jax, numpyro and optax
80 | ```
81 | An example of how to use the Monitor class and the LBFGS initialization is in `examples/example_initializers.py`
82 | ```
83 | cd examples
84 | python3 example_initializers.py # jax version, requires jax, numpyro, optax and scipy
85 | ```
86 |
87 | ## Batch and match: black-box variational inference with a score-based divergence (ICML 2024)
88 |
89 | [Batch and match (BaM)](https://arxiv.org/pdf/2402.14758) also fits a full-covariance multivariate Gaussian and recovers (a version of) GSM as a special case.
90 | Instead of matching scores directly, the BaM algorithm minimizes a score-based divergence.
91 | The code is set up similarly to the GSM code. Currently, it is not yet available
92 | on `PyPI`.
93 |
94 | To install, run
95 | ```
96 | pip install -e .
97 | ```
98 |
99 |
100 | The example usage code is in `examples/example_bam.py`:
101 |
102 | ```
103 | cd examples
104 | python3 example_bam.py # jax version, requires jax and numpyro
105 | ```
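
The BaM API mirrors GSM, with an additional regularization schedule passed as `regf`; a minimal sketch, assuming the `lp` and `lp_g` functions from the GSM example above (see `examples/example_bam.py` for the full version):
```
from jax import random
from gsmvi.bam import BaM, Regularizers

regularizer = Regularizers()
regf = regularizer.custom(lambda t: 100/(1+t))  # regularization decaying with iteration t

bam = BaM(D=D, lp=lp, lp_g=lp_g, use_lowrank=False, jit_compile=True)
mean_fit, cov_fit = bam.fit(random.PRNGKey(99), regf=regf, niter=100, batch_size=2)
```
The full example also enables 64-bit precision in jax, which can help the matrix square roots inside the BaM update stay numerically stable.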
106 |
107 | Note that this installation approach also includes the GSM and ADVI code
108 | used in the examples above.
109 |
110 |
--------------------------------------------------------------------------------
/gsmvi/advi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax
3 | import jax.numpy as jnp
4 | from jax import jit, grad, random
5 | from numpyro.distributions import MultivariateNormal
6 | import optax
7 |
8 | class ADVI():
9 | """
10 | Class for fitting a multivariate Gaussian distribution with dense covariance matrix
11 | by maximizing ELBO.
12 | """
13 |
14 | def __init__(self, D, lp):
15 | """
16 | Inputs:
17 | D: (int) Dimensionality (number) of parameters.
18 | lp : Function to evaluate target log-probability distribution
19 | whose gradient can be evaluated with jax.grad(lp)
20 | """
21 | self.D = D
22 | self.lp = lp
23 | self.idx_tril = jnp.stack(jnp.tril_indices(D)).T
24 |
25 | def scales_to_cov(self, scales):
26 | scale_tril = jnp.zeros((self.D, self.D))
27 | scale_tril = scale_tril.at[self.idx_tril[:, 0], self.idx_tril[:, 1]].set(scales)
28 | cov = np.matmul(scale_tril, scale_tril.T)
29 | return cov
30 |
31 | def neg_elbo(self, params, key, batch_size):
32 | """
33 | Internal function to evaluate negative-ELBO which is the loss function
34 | """
35 | loc, scales = params
36 | scale_tril = jnp.zeros((self.D, self.D))
37 | scale_tril = scale_tril.at[self.idx_tril[:, 0], self.idx_tril[:, 1]].set(scales)
38 | q = MultivariateNormal(loc=loc, scale_tril=scale_tril)
39 | #
40 | samples = q.sample(key, (batch_size,))
41 | logl = jnp.sum(self.lp(samples))
42 | logq = jnp.sum(q.log_prob(samples))
43 | elbo = logl - logq
44 | negelbo = -1. * elbo
45 | return negelbo
46 |
47 | def fit(self, key, opt, mean=None, cov=None, batch_size=8, niter=1000, nprint=10, monitor=None):
48 | """
49 | Main function to fit a multivariate Gaussian distribution to the target
50 |         Inputs:
51 |             key: Random number generator key (jax.random.PRNGKey)
52 |             opt: An optax optimizer instance, e.g. optax.adam(learning_rate=1e-2)
53 | mean : Optional, initial value of the mean. Expected None or array of size D
54 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD
55 | batch_size : Optional, int. Number of samples to match scores for at every iteration
56 | niter : Optional, int. Total number of iterations
57 | nprint : Optional, int. Number of iterations after which to print logs
58 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics.
59 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals).
60 |                       Example of monitor class is provided in gsmvi/monitors.py
61 | Returns:
62 | mu : Array of shape D, fit of the mean
63 | cov : Array of shape DxD, fit of the covariance matrix
64 | """
65 |
66 | lossf = jit(self.neg_elbo, static_argnums=(2))
67 |
68 | @jit
69 | def opt_step(params, opt_state, key):
70 | loss, grads = jax.value_and_grad(lossf, argnums=0)(params, key, batch_size=batch_size)
71 | updates, opt_state = opt.update(grads, opt_state, params)
72 | params = optax.apply_updates(params, updates)
73 | return params, opt_state, loss
74 |
75 | if mean is None:
76 | mean = jnp.zeros(self.D)
77 | if cov is None:
78 | cov = np.identity(self.D)
79 |
80 | # Optimization is done on unconstrained Cholesky factors of covariance matrix
81 | L = np.linalg.cholesky(cov)
82 | scales = jnp.array(L[np.tril_indices(self.D)])
83 | params = (mean, scales)
84 |
85 | # run optimization
86 | opt_state = opt.init(params)
87 | losses = []
88 | nevals = 1
89 |
90 | for i in range(niter + 1):
91 |             if (i % max(niter // nprint, 1) == 0):
92 | print(f'Iteration {i} of {niter}')
93 | if monitor is not None:
94 | if (i%monitor.checkpoint) == 0:
95 | mean = params[0]
96 | cov = self.scales_to_cov( params[1]*1.)
97 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
98 | nevals = 0
99 |
100 | params, opt_state, loss = opt_step(params, opt_state, key)
101 | key, _ = random.split(key)
102 | losses.append(loss)
103 | nevals += batch_size
104 |
105 |
106 | # Convert back to mean and covariance matrix
107 | mean = params[0]
108 | cov = self.scales_to_cov( params[1]*1.)
109 | if monitor is not None:
110 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
111 |
112 | return mean, cov, losses
113 |
--------------------------------------------------------------------------------
/gsmvi/gsm_numpy.py:
--------------------------------------------------------------------------------
1 | ## Pure numpy implementation of GSM updates.
2 | import numpy as np
3 |
4 | def _gsm_update_single(sample, v, mu0, S0):
5 | '''returns GSM update to mean and covariance matrix for a single sample
6 | '''
7 | S0v = np.matmul(S0, v)
8 | vSv = np.matmul(v, S0v)
9 | mu_v = np.matmul((mu0 - sample), v)
10 | rho = 0.5 * np.sqrt(1 + 4*(vSv + mu_v**2)) - 0.5
11 | eps0 = S0v - mu0 + sample
12 |
13 | #mu update
14 | mu_vT = np.outer((mu0 - sample), v)
15 | den = 1 + rho + mu_v
16 | I = np.eye(sample.shape[0])
17 | mu_update = 1/(1 + rho) * np.matmul(( I - mu_vT / den), eps0)
18 | mu = mu0 + mu_update
19 |
20 | #S update
21 | Supdate_0 = np.outer((mu0-sample), (mu0-sample))
22 | Supdate_1 = np.outer((mu-sample), (mu-sample))
23 | S_update = (Supdate_0 - Supdate_1)
24 | return mu_update, S_update
25 |
26 |
27 | def gsm_update(samples, vs, mu0, S0):
28 | """
29 | Returns updated mean and covariance matrix with GSM updates.
30 | For a batch, this is simply the mean of updates for individual samples.
31 |
32 | Inputs:
33 | samples: Array of samples of shape BxD where B is the batch dimension
34 | vs : Array of score functions of shape BxD corresponding to samples
35 | mu0 : Array of shape D, current estimate of the mean
36 | S0 : Array of shape DxD, current estimate of the covariance matrix
37 |
38 | Returns:
39 | mu : Array of shape D, new estimate of the mean
40 | S : Array of shape DxD, new estimate of the covariance matrix
41 | """
42 |
43 | assert len(samples.shape) == 2
44 | assert len(vs.shape) == 2
45 |
46 | B, D = samples.shape
47 | mu_update, S_update = np.zeros((B, D)), np.zeros((B, D, D))
48 | for i in range(B):
49 | mu_update[i], S_update[i] = _gsm_update_single(samples[i], vs[i], mu0, S0)
50 | mu_update = np.mean(mu_update, axis=0)
51 | S_update = np.mean(S_update, axis=0)
52 | mu = mu0 + mu_update
53 | S = S0 + S_update
54 |
55 | return mu, S
56 |
57 |
58 |
59 |
60 | class GSM:
61 | """
62 | Wrapper class for using GSM updates to fit a distribution
63 | """
64 | def __init__(self, D, lp, lp_g):
65 | """
66 | Inputs:
67 | D: (int) Dimensionality (number) of parameters
68 | lp : Function to evaluate target log-probability distribution.
69 | (Only used in monitor, not for fitting)
70 | lp_g : Function to evaluate score, i.e. the gradient of the target log-probability distribution
71 | """
72 | self.D = D
73 | self.lp = lp
74 | self.lp_g = lp_g
75 |
76 |
77 | def fit(self, key, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None):
78 | """
79 | Main function to fit a multivariate Gaussian distribution to the target
80 |
81 | Inputs:
82 | key: Seed for random number generator
83 | mean : Optional, initial value of the mean. Expected None or array of size D
84 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD
85 | batch_size : Optional, int. Number of samples to match scores for at every iteration
86 | niter : Optional, int. Total number of iterations
87 | nprint : Optional, int. Number of iterations after which to print logs
88 | verbose : Optional, bool. If true, print number of iterations after nprint
89 |             check_goodness : Optional, bool. Recommended. Whether to check for floating point errors in the covariance matrix update
90 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics.
91 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals).
92 |                      Example of monitor class is provided in gsmvi/monitors.py
93 |
94 | Returns:
95 | mu : Array of shape D, fit of the mean
96 | cov : Array of shape DxD, fit of the covariance matrix
97 | """
98 | if mean is None:
99 | mean = np.zeros(self.D)
100 | if cov is None:
101 | cov = np.identity(self.D)
102 |
103 | nevals = 1
104 |
105 | np.random.seed(key)
106 | for i in range(niter + 1):
107 |             if (i % max(niter // nprint, 1) == 0) and verbose:
108 | print(f'Iteration {i} of {niter}')
109 |
110 | if monitor is not None:
111 | if (i%monitor.checkpoint) == 0:
112 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
113 | nevals = 0
114 |
115 |             # Generate samples from the current Gaussian fit
116 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size)
117 | vs = self.lp_g(samples)
118 | mean_new, cov_new = gsm_update(samples, vs, mean, cov)
119 | nevals += batch_size
120 |
121 | is_good = self._check_goodness(cov_new)
122 | if is_good:
123 | mean, cov = mean_new, cov_new
124 | else:
125 | if verbose: print("Bad update for covariance matrix. Revert")
126 |
127 | if monitor is not None:
128 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
129 | return mean, cov
130 |
131 |
132 |     def _check_goodness(self, cov):
133 |         '''
134 |         Internal function to check if the new covariance matrix is a valid covariance matrix.
135 |         Required due to floating point errors from updating the covariance matrix directly,
136 |         instead of its Cholesky form.
137 |         '''
138 |         is_good = False
139 |         try:
140 |             if np.isnan(np.linalg.cholesky(cov)).any():
141 |                 pass  # NaN in the Cholesky factor; keep is_good = False
142 |             else:
143 |                 is_good = True
144 |             return is_good
145 |         except np.linalg.LinAlgError:
146 |             return is_good
147 |
--------------------------------------------------------------------------------
/gsmvi/gsm.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import jit, random
4 | #from numpyro.distributions import MultivariateNormal ##Needed if sampling from numpyro dist below
5 | import numpy as np
6 |
7 | @jit
8 | def _gsm_update_single(sample, v, mu0, S0):
9 | '''returns GSM update to mean and covariance matrix for a single sample
10 | '''
11 | S0v = jnp.matmul(S0, v)
12 | vSv = jnp.matmul(v, S0v)
13 | mu_v = jnp.matmul((mu0 - sample), v)
14 | rho = 0.5 * jnp.sqrt(1 + 4*(vSv + mu_v**2)) - 0.5
15 | eps0 = S0v - mu0 + sample
16 |
17 | #mu update
18 | mu_vT = jnp.outer((mu0 - sample), v)
19 | den = 1 + rho + mu_v
20 | I = jnp.eye(sample.shape[0])
21 | mu_update = 1/(1 + rho) * jnp.matmul(( I - mu_vT / den), eps0)
22 | mu = mu0 + mu_update
23 |
24 | #S update
25 | Supdate_0 = jnp.outer((mu0-sample), (mu0-sample))
26 | Supdate_1 = jnp.outer((mu-sample), (mu-sample))
27 | S_update = (Supdate_0 - Supdate_1)
28 | return mu_update, S_update
29 |
30 |
31 | @jit
32 | def gsm_update(samples, vs, mu0, S0):
33 | """
34 | Returns updated mean and covariance matrix with GSM updates.
35 | For a batch, this is simply the mean of updates for individual samples.
36 |
37 | Inputs:
38 | samples: Array of samples of shape BxD where B is the batch dimension
39 | vs : Array of score functions of shape BxD corresponding to samples
40 | mu0 : Array of shape D, current estimate of the mean
41 | S0 : Array of shape DxD, current estimate of the covariance matrix
42 |
43 | Returns:
44 | mu : Array of shape D, new estimate of the mean
45 | S : Array of shape DxD, new estimate of the covariance matrix
46 | """
47 |
48 | assert len(samples.shape) == 2
49 | assert len(vs.shape) == 2
50 |
51 | vgsm_update = jax.vmap(_gsm_update_single, in_axes=(0, 0, None, None))
52 | mu_update, S_update = vgsm_update(samples, vs, mu0, S0)
53 | mu_update = jnp.mean(mu_update, axis=0)
54 | S_update = jnp.mean(S_update, axis=0)
55 | mu = mu0 + mu_update
56 | S = S0 + S_update
57 |
58 | return mu, S
59 |
60 |
61 |
62 | class GSM:
63 | """
64 | Wrapper class for using GSM updates to fit a distribution
65 | """
66 | def __init__(self, D, lp, lp_g):
67 | """
68 | Inputs:
69 | D: (int) Dimensionality (number) of parameters
70 | lp : Function to evaluate target log-probability distribution.
71 | (Only used in monitor, not for fitting)
72 | lp_g : Function to evaluate score, i.e. the gradient of the target log-probability distribution
73 | """
74 | self.D = D
75 | self.lp = lp
76 | self.lp_g = lp_g
77 |
78 |
79 | def fit(self, key, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None):
80 | """
81 | Main function to fit a multivariate Gaussian distribution to the target
82 |
83 | Inputs:
84 | key: Random number generator key (jax.random.PRNGKey)
85 | mean : Optional, initial value of the mean. Expected None or array of size D
86 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD
87 | batch_size : Optional, int. Number of samples to match scores for at every iteration
88 | niter : Optional, int. Total number of iterations
89 | nprint : Optional, int. Number of iterations after which to print logs
90 | verbose : Optional, bool. If true, print number of iterations after nprint
91 |             check_goodness : Optional, bool. Recommended. Whether to check for floating point errors in the covariance matrix update
92 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics.
93 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals).
94 |                       Example of monitor class is provided in gsmvi/monitors.py
95 |
96 | Returns:
97 | mu : Array of shape D, fit of the mean
98 | cov : Array of shape DxD, fit of the covariance matrix
99 | """
100 | if mean is None:
101 | mean = jnp.zeros(self.D)
102 | if cov is None:
103 | cov = jnp.identity(self.D)
104 |
105 | nevals = 1
106 |
107 | for i in range(niter + 1):
108 |             if (i % max(niter // nprint, 1) == 0) and verbose:
109 | print(f'Iteration {i} of {niter}')
110 |
111 | if monitor is not None:
112 | if (i%monitor.checkpoint) == 0:
113 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
114 | nevals = 0
115 |
116 | # Can generate samples from jax distribution (commented below), but using numpy is faster
117 | key, key_sample = random.split(key, 2)
118 | np.random.seed(key_sample[0])
119 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size)
120 | # samples = MultivariateNormal(loc=mean, covariance_matrix=cov).sample(key, (batch_size,))
121 | vs = self.lp_g(samples)
122 | mean_new, cov_new = gsm_update(samples, vs, mean, cov)
123 | nevals += batch_size
124 |
125 | is_good = self._check_goodness(cov_new)
126 | if is_good:
127 | mean, cov = mean_new, cov_new
128 | else:
129 | if verbose: print("Bad update for covariance matrix. Revert")
130 |
131 | if monitor is not None:
132 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
133 | return mean, cov
134 |
135 |
136 |     def _check_goodness(self, cov):
137 |         '''
138 |         Internal function to check if the new covariance matrix is a valid covariance matrix.
139 |         Required due to floating point errors from updating the covariance matrix directly,
140 |         instead of its Cholesky form.
141 |         '''
142 |         is_good = False
143 |         try:
144 |             if np.isnan(np.linalg.cholesky(cov)).any():
145 |                 pass  # NaN in the Cholesky factor; keep is_good = False
146 |             else:
147 |                 is_good = True
148 |             return is_good
149 |         except np.linalg.LinAlgError:
150 |             return is_good
151 |
--------------------------------------------------------------------------------
/gsmvi/bam.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import jit, random
4 | from jax.scipy.linalg import sqrtm as sqrtm_jsp
5 | from scipy.linalg import sqrtm as sqrtm_sp
6 | import numpy as np
7 | import scipy.sparse as spys
8 | from jax.lib import xla_bridge
9 |
10 | def compute_Q_host(U_B):
11 | U, B = U_B
12 | UU, DD, VV = spys.linalg.svds(U, k=B)
13 | return UU * np.sqrt(DD)
14 |
15 | def compute_Q(U_B):
16 | result_shape = jax.ShapeDtypeStruct((U_B[0].shape[0], U_B[1]), U_B[0].dtype)
17 | return jax.pure_callback(compute_Q_host, result_shape, U_B)
18 |
19 | def get_sqrt(M):
20 | if xla_bridge.get_backend().platform == 'gpu':
21 | result_shape = jax.ShapeDtypeStruct(M.shape, M.dtype)
22 | M_root = jax.pure_callback(lambda x:sqrtm_sp(x).astype(M.dtype), result_shape, M) # sqrt can be complex sometimes, we only want real part
23 | elif xla_bridge.get_backend().platform == 'cpu':
24 | M_root = sqrtm_jsp(M)
25 | else:
26 |         raise ValueError("Backend not recognized in get_sqrt function. "
27 |                          "Should be either gpu or cpu")
28 | return M_root.real
29 |
30 |
31 | def bam_update(samples, vs, mu0, S0, reg):
32 | """
33 |     Returns updated mean and covariance matrix with batch and match (BaM) updates.
34 |     The update is computed from the batch statistics of the samples and their scores.
35 |
36 | Inputs:
37 | samples: Array of samples of shape BxD where B is the batch dimension
38 | vs : Array of score functions of shape BxD corresponding to samples
39 | mu0 : Array of shape D, current estimate of the mean
40 | S0 : Array of shape DxD, current estimate of the covariance matrix
41 |         reg : (float) Regularization strength at this iteration
42 | Returns:
43 | mu : Array of shape D, new estimate of the mean
44 | S : Array of shape DxD, new estimate of the covariance matrix
45 | """
46 |
47 | assert len(samples.shape) == 2
48 | assert len(vs.shape) == 2
49 | B = samples.shape[0]
50 | xbar = jnp.mean(samples, axis=0)
51 | outer_map = jax.vmap(jnp.outer, in_axes=(0, 0))
52 | xdiff = samples - xbar
53 | C = jnp.mean(outer_map(xdiff, xdiff), axis=0)
54 |
55 | gbar = jnp.mean(vs, axis=0)
56 | gdiff = vs - gbar
57 | G = jnp.mean(outer_map(gdiff, gdiff), axis=0)
58 |
59 | U = reg * G + (reg)/(1+reg) * jnp.outer(gbar, gbar)
60 | V = S0 + reg * C + (reg)/(1+reg) * jnp.outer(mu0 - xbar, mu0 - xbar)
61 | I = jnp.identity(samples.shape[1])
62 |
63 | mat = I + 4 * jnp.matmul(U, V)
64 | # S = 2 * jnp.matmul(V, jnp.linalg.inv(I + sqrtm(mat).real))
65 | S = 2 * jnp.linalg.solve(I + get_sqrt(mat).T, V.T)
66 |
67 | mu = 1/(1+reg) * mu0 + reg/(1+reg) * (jnp.matmul(S, gbar) + xbar)
68 |
69 | return mu, S
70 |
71 |
72 | def bam_lowrank_update(samples, vs, mu0, S0, reg):
73 | """
74 | Returns updated mean and covariance matrix with low-rank BaM updates.
75 |     The update is computed from the batch statistics, using a low-rank factorization to avoid a dense matrix square root.
76 |
77 | Inputs:
78 | samples: Array of samples of shape BxD where B is the batch dimension
79 | vs : Array of score functions of shape BxD corresponding to samples
80 | mu0 : Array of shape D, current estimate of the mean
81 | S0 : Array of shape DxD, current estimate of the covariance matrix
82 |         reg : (float) Regularization strength at this iteration
83 | Returns:
84 | mu : Array of shape D, new estimate of the mean
85 | S : Array of shape DxD, new estimate of the covariance matrix
86 | """
87 |
88 | assert len(samples.shape) == 2
89 | assert len(vs.shape) == 2
90 | B = samples.shape[0]
91 | xbar = jnp.mean(samples, axis=0)
92 | outer_map = jax.vmap(jnp.outer, in_axes=(0, 0))
93 | xdiff = samples - xbar
94 | C = jnp.mean(outer_map(xdiff, xdiff), axis=0)
95 |
96 | gbar = jnp.mean(vs, axis=0)
97 | gdiff = vs - gbar
98 | G = jnp.mean(outer_map(gdiff, gdiff), axis=0)
99 |
100 | U = reg * G + (reg)/(1+reg) * jnp.outer(gbar, gbar)
101 | V = S0 + reg * C + (reg)/(1+reg) * jnp.outer(mu0 - xbar, mu0 - xbar)
102 |
103 | # Form decomposition that is D x K
104 | Q = compute_Q((U, B))
105 | I = jnp.identity(B)
106 | VT = V.T
107 | A = VT.dot(Q)
108 | BB = 0.5*I + jnp.real(get_sqrt(A.T.dot(Q) + 0.25*I))
109 | BB = BB.dot(BB)
110 | CC = jnp.linalg.solve(BB, A.T)
111 | S = VT - A @ CC
112 | mu = 1/(1+reg) * mu0 + reg/(1+reg) * (jnp.matmul(S, gbar) + xbar)
113 |
114 | return mu, S
115 |
116 |
117 | class BaM:
118 | """
119 | Wrapper class for using BaM updates to fit a distribution
120 | """
121 | def __init__(self, D, lp, lp_g, use_lowrank=False, jit_compile=True):
122 | """
123 | Inputs:
124 | D: (int) Dimensionality (number) of parameters
125 | lp : Function to evaluate target log-probability distribution.
126 | (Only used in monitor, not for fitting)
127 | lp_g : Function to evaluate score, i.e. the gradient of the target log-probability distribution
128 | """
129 | self.D = D
130 | self.lp = lp
131 | self.lp_g = lp_g
132 | self.use_lowrank = use_lowrank
133 | if use_lowrank:
134 | print("Using lowrank update")
135 | self.jit_compile = jit_compile
136 | if not jit_compile:
137 | print("Not using jit compilation. This may take longer than it needs to.")
138 |
139 |
140 | def fit(self, key, regf, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None, retries=10, jitter=1e-6):
141 | """
142 | Main function to fit a multivariate Gaussian distribution to the target
143 |
144 | Inputs:
145 | key: Random number generator key (jax.random.PRNGKey)
146 |             regf : Function to return the regularizer value at an iteration. See the Regularizers class below
147 | mean : Optional, initial value of the mean. Expected None or array of size D
148 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD
149 | batch_size : Optional, int. Number of samples to match scores for at every iteration
150 | niter : Optional, int. Total number of iterations
151 | nprint : Optional, int. Number of iterations after which to print logs
152 | verbose : Optional, bool. If true, print number of iterations after nprint
153 |             check_goodness : Optional, bool. Recommended. Whether to check for floating point errors in the covariance matrix update
154 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics.
155 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals).
156 |                       Example of monitor class is provided in gsmvi/monitors.py
157 |
158 | Returns:
159 | mu : Array of shape D, fit of the mean
160 | cov : Array of shape DxD, fit of the covariance matrix
161 | """
162 |
163 | if mean is None:
164 | mean = jnp.zeros(self.D)
165 | if cov is None:
166 | cov = jnp.identity(self.D)
167 |
168 | nevals = 1
169 |
170 | if self.use_lowrank:
171 | update_function = bam_lowrank_update
172 | else:
173 | update_function = bam_update
174 | if self.jit_compile:
175 | update_function = jit(update_function)
176 |
177 | if nprint > niter: nprint = niter
178 | for i in range(niter + 1):
179 | if (i%(niter//nprint) == 0) and verbose :
180 | print(f'Iteration {i} of {niter}')
181 |
182 | if monitor is not None:
183 | if (i%monitor.checkpoint) == 0:
184 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
185 | nevals = 0
186 |
187 | # Can generate samples from jax distribution (commented below), but using numpy is faster
188 | j = 0
189 | while True: # Sometimes run crashes due to a bad sample. Avoid that by re-trying.
190 | try:
191 | key, key_sample = random.split(key, 2)
192 | np.random.seed(key_sample[0])
193 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size)
194 | vs = self.lp_g(samples)
195 | nevals += batch_size
196 | reg = regf(i)
197 | mean_new, cov_new = update_function(samples, vs, mean, cov, reg)
198 | cov_new += np.eye(self.D) * jitter # jitter covariance matrix
199 | cov_new = (cov_new + cov_new.T)/2.
200 | break
201 | except Exception as e:
202 | if j < retries :
203 | j += 1
204 | print(f"Failed with exception {e}")
205 | print(f"Trying again {j} of {retries}")
206 | else : raise e
207 |
208 | is_good = self._check_goodness(cov_new)
209 | if is_good:
210 | mean, cov = mean_new, cov_new
211 | else:
212 | if verbose: print("Bad update for covariance matrix. Revert")
213 |
214 | if monitor is not None:
215 | monitor(i, [mean, cov], self.lp, key, nevals=nevals)
216 | return mean, cov
217 |
218 |
219 |     def _check_goodness(self, cov):
220 |         '''
221 |         Internal function to check if the new covariance matrix is a valid covariance matrix.
222 |         Required due to floating point errors from updating the covariance matrix directly,
223 |         instead of its Cholesky form.
224 |         '''
225 |         is_good = False
226 |         try:
227 |             if np.isnan(np.linalg.cholesky(cov)).any():
228 |                 pass  # NaN in the Cholesky factor; keep is_good = False
229 |             else:
230 |                 is_good = True
231 |             return is_good
232 |         except np.linalg.LinAlgError:
233 |             return is_good
234 |
235 |
236 |
237 | class Regularizers():
238 |     """
239 |     Class of regularization schedules for BaM.
240 |
241 |     The constant, linear and custom methods each return a function regf(iteration)
242 |     which is passed to BaM.fit and called every iteration to get the current
243 |     regularization strength. An internal counter tracks the number of calls
244 |     across fits; use reset() to start a new schedule.
245 |     """
246 |
247 |     def __init__(self):
248 |         self.counter = 0
249 |
250 |     def reset(self):
251 |         self.counter = 0
252 |
253 |     def constant(self, reg0):
254 |         '''Constant schedule: always returns reg0'''
255 |         def reg_iter(iteration):
256 |             self.counter += 1
257 |             return reg0
258 |         return reg_iter
259 |
260 |     def linear(self, reg0):
261 |         '''Decaying schedule: returns reg0/t at the t-th call'''
262 |         def reg_iter(iteration):
263 |             self.counter += 1
264 |             return reg0/self.counter
265 |         return reg_iter
266 |
267 |     def custom(self, func):
268 |         '''User-supplied schedule: returns func(t) at the t-th call'''
269 |         def reg_iter(iteration):
270 |             self.counter += 1
271 |             return func(self.counter)
272 |         return reg_iter
273 |
--------------------------------------------------------------------------------