├── overview.png ├── BASQ ├── _metric.py ├── experiment │ ├── uncertainty.py │ └── ecm.py ├── _posterior.py ├── _utils.py ├── _basq.py ├── _vbq.py ├── _gaussian_calc.py ├── _rchq.py ├── _scale_mmlt_wsabi.py ├── _quadrature.py ├── _gp.py ├── _wsabi.py ├── _mmlt_wsabi.py ├── _parameters.py ├── _acquisition_function.py ├── _sampler.py └── _lbfgs.py ├── LICENSE └── README.md /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Battery-Intelligence-Lab/BayesianModelSelection/HEAD/overview.png -------------------------------------------------------------------------------- /BASQ/_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KLdivergence: 5 | def __init__(self, prior, test_data, Z_true, device, true_function): 6 | """ 7 | Args: 8 | - prior: torch.distributions, prior distribution 9 | - test_data: torch.tensor, samples for evaluation, sampled from prior 10 | - Z_true: float, the true evidence. 11 | - device: torch.device, cpu or cuda 12 | - true_function: function of y = function(x), true likelihood funciton 13 | """ 14 | self.prior = prior 15 | self.test_data = test_data 16 | self.Z_true = Z_true 17 | self.device = device 18 | self.true_function = true_function 19 | 20 | def __call__(self, basq_model): 21 | """ 22 | Args: 23 | - basq_mode: gpytorch.models, the trained BQ model 24 | - EZy: float, the estimated evidence 25 | 26 | Returns: 27 | - KL: torch.tensor, the Kullback-Leibler divergence between true posterior and estimated posterior 28 | """ 29 | KL = torch.zeros(len(self.test_data)).to(self.device) 30 | q = torch.squeeze( 31 | self.prior.log_prob(self.test_data).exp() * self.true_function(self.test_data).exp() / self.Z_true 32 | ) 33 | p = basq_model.posterior.joint_posterior(self.test_data) 34 | KL[p > 0] = p[p > 0] * torch.log(p[p > 0] / q[p > 0]) 35 | _KL = KL[KL.isinf().logical_not()] 36 | return torch.abs(_KL.sum() * len(self.test_data) / len(_KL)) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Battery Intelligence Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BayesianModelSelection 2 | This repository contains the Python code for the paper presented at IFAC 2023: 3 | 4 | Adachi, M., Kuhn, Y., Horstmann, B., Latz, A., Osborne, M. A., Howey, D. A. 5 | Bayesian Model Selection of Lithium-Ion Battery Models via Bayesian Quadrature, IFAC 2023 [link](https://doi.org/10.1016/j.ifacol.2023.10.1073) 6 | 7 | This work is based on the BASQ [repository](https://github.com/ma921/BASQ). 8 | 9 | ![Animate](./overview.png) 10 | 11 | 12 | ## News 13 | We have recently published a new method, SOBER, which achieves faster convergence: 14 | https://github.com/ma921/SOBER
15 | Try out tutorial 05 there for a comparison. 16 | 17 | ## Features 18 | - Fast Bayesian inference via Bayesian quadrature 19 | - Simultaneous inference of Bayesian model evidence and posterior 20 | - GPU acceleration 21 | - Canonical equivalent circuit model (ECM) 22 | - Statistical analysis of the ECM 23 | 24 | ## Requirements 25 | - PyTorch 26 | - GPyTorch 27 | - BoTorch 28 | - functorch 29 | 30 | ## Getting started 31 | Open "ECM_model_selection.ipynb". 32 | This notebook gives a step-by-step introduction. 33 | 34 | ## Cite as 35 | 36 | Please cite this work as 37 | ``` 38 | @article{adachi2023bayesian, 39 | title={Bayesian model selection of lithium-ion battery models via {B}ayesian quadrature}, 40 | author={Adachi, Masaki and Kuhn, Yannick and Horstmann, Birger and Latz, Arnulf and Osborne, Michael A and Howey, David A}, 41 | journal={IFAC-PapersOnLine}, 42 | volume={56}, 43 | number={2}, 44 | pages={10521--10526}, 45 | year={2023}, 46 | doi={https://doi.org/10.1016/j.ifacol.2023.10.1073}, 47 | publisher={Elsevier} 48 | } 49 | ``` 50 | Please also consider citing the following related work. 51 | ``` 52 | @article{adachi2022fast, 53 | title={Fast {B}ayesian inference with batch {B}ayesian quadrature via kernel recombination}, 54 | author={Adachi, Masaki and Hayakawa, Satoshi and J{\o}rgensen, Martin and Oberhauser, Harald and Osborne, Michael A}, 55 | journal={Advances in Neural Information Processing Systems}, 56 | volume={35}, 57 | doi={https://doi.org/10.48550/arXiv.2206.04734}, 58 | year={2022} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /BASQ/experiment/uncertainty.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import scipy.stats as stats 4 | 5 | def calc_params(params_true): 6 | Rt = torch.exp(params_true[0]) 7 | _r1 = torch.exp(-torch.exp(params_true[1])) 8 | t1 = params_true[2] 9 | _r2 = torch.exp(-torch.exp(params_true[3])) 10 | t2 = params_true[4] 11 | sigma_noise = torch.exp(-torch.exp(params_true[5])) 12 | r0 = 1 - _r1 - _r2 13 | 14 | w1 = _r1 / (1 - r0) 15 | w2 = _r2 / (1 - r0) 16 | sigma_omega = torch.tensor(6.07285737991333) 17 | ba = 2*torch.tensor(20.723264694213867) 18 | 19 | return w1, w2, t1, t2, Rt, r0, sigma_noise, ba, sigma_omega 20 | 21 | def symmetric_KL(w1, w2, tau1, tau2, sigma_omega): 22 | Nis = 10000000 23 | 24 | if tau1 > tau2: 25 | temp = copy.deepcopy(tau1) 26 | tau1 = copy.deepcopy(tau2) 27 | tau2 = copy.deepcopy(temp) 28 | temp = copy.deepcopy(w1) 29 | w1 = copy.deepcopy(w2) 30 | w2 = copy.deepcopy(temp) 31 | 32 | samples = torch.from_numpy(stats.hypsecant.rvs(size=Nis)) 33 | Nmid = int(Nis / 2) 34 | N1 = int(Nmid * w1) 35 | N2 = int(Nis - Nmid - N1) 36 | delta12 = sigma_omega * (w2*tau2 - w1*tau1) 37 | 38 | samples_N1 = samples[Nmid:Nmid+N1]/w1 - sigma_omega * tau1 39 | samples_N2 = samples[Nmid+N1:]/w2 - sigma_omega * tau2 40 | samples_Nmid = samples[:Nmid]*2 - 0.5 * delta12 - sigma_omega * w1 * tau1 41 | samples_g = torch.cat([samples_Nmid, samples_N1, samples_N2]) 42 | 43 | # importance sampling 44 | P1 = w1/torch.pi/torch.cosh(w1*(samples_g + sigma_omega * tau1)) 45 | P2 = w2/torch.pi/torch.cosh(w2*(samples_g + sigma_omega * tau2)) 46 | g = 1/4*( 47 | w1/torch.pi/torch.cosh(w1*(samples_g + sigma_omega * tau1)) 48 | + w2/torch.pi/torch.cosh(w2*(samples_g + sigma_omega * tau2)) 49 | ) + 1/4/torch.pi/torch.cosh( 50 | 0.5*(samples_g + sigma_omega * w1 * tau1 + 0.5*delta12) 51 | ) 52 | KL12 = 
((P1/P2).log()*P1/g).mean() 53 | KL21 = ((P2/P1).log()*P2/g).mean() 54 | Dkl = KL12 + KL21 55 | return Dkl 56 | 57 | def signal_to_noise(w1, w2, t1, t2, Rt, r0, sigma_noise, ba, sigma_omega): 58 | W12 = w1 * w2 59 | Wsum = w1**2 + w2**2 60 | Delta_mu_12 = sigma_omega * torch.abs(t1 - t2) 61 | 62 | EZ = Rt * torch.pi * (1 - r0) / ba 63 | A = Wsum + 2 * W12 * Delta_mu_12 / torch.sinh(Delta_mu_12) 64 | EZ2 = (Rt**2) * ((1-r0)**2) / ba * A 65 | SNR = torch.log((EZ2 - EZ**2) / sigma_noise) 66 | return SNR 67 | 68 | def calc_uncertainties(params_true): 69 | w1, w2, t1, t2, Rt, r0, sigma_noise, ba, sigma_omega = calc_params(params_true) 70 | KL = symmetric_KL(w1, w2, t1, t2, sigma_omega) 71 | SNR = signal_to_noise(w1, w2, t1, t2, Rt, r0, sigma_noise, ba, sigma_omega) 72 | rho = SNR * KL 73 | return torch.tensor([KL, SNR, rho]).float() -------------------------------------------------------------------------------- /BASQ/_posterior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Posterior: 5 | def __init__(self, bq_model, prior, gp, predict_mean, kq, sampler, sampler_type): 6 | """ 7 | Args: 8 | - bq_model: string, ["mmlt", "wsabi", "vbq"] 9 | - prior: torch.distributions, prior distribution. 10 | - gp: class, Gaussian process module 11 | - predict_mean: function of mean = function(x), the function that returns the predictive mean at given x 12 | - kq: class, Kernel Quadrature module 13 | - sampler: class, sampler module 14 | - sampler_type: string, type of sampler 15 | """ 16 | self.bq_model = bq_model 17 | self.prior = prior 18 | self.gp = gp 19 | self.predict_mean = predict_mean 20 | self.kq = kq 21 | self.sampler = sampler 22 | self.sampler_type = sampler_type 23 | 24 | def check_evidence(self): 25 | if not hasattr(self, "EZy"): 26 | if self.bq_model == "mmlt": 27 | if not hasattr(self.kq, "logEZy"): 28 | logEZy, _ = self.kq.quadrature() 29 | else: 30 | logEZy = self.kq.logEZy 31 | self.EZy = (logEZy - self.gp.beta).exp() 32 | else: 33 | if not hasattr(self.kq, "EZy"): 34 | self.EZy, _ = self.kq.quadrature() 35 | else: 36 | self.EZy = self.kq.EZy 37 | 38 | def joint_posterior(self, x): 39 | """ 40 | Args: 41 | - x: torch.tensor, inputs. 
torch.Size(n_data, n_dims) 42 | 43 | Returns: 44 | - torch.tensor, the posterior of given x 45 | """ 46 | self.check_evidence() 47 | return self.predict_mean(x) * self.prior.log_prob(x).exp() / self.EZy 48 | 49 | def sample(self, n): 50 | """ 51 | Args: 52 | - n: int, number of samples to be generated 53 | 54 | Returns: 55 | - samples: torch.tensor, the samples drawn from posterior 56 | """ 57 | if self.bq_model == "mmlt" and self.sampler_type == "uncertainty": 58 | self.sampler.update(self.gp.model) 59 | n_super = int(self.sampler.ratio_super * n) 60 | supersample = self.sampler.sampling(n_super) 61 | 62 | mu_log = self.predict_mean(supersample).detach().abs().log() 63 | prior_log = self.prior.log_prob(supersample) 64 | sampler_log = torch.nan_to_num(self.sampler.joint_pdf(supersample)).log() 65 | weights = torch.exp( 66 | mu_log + prior_log - sampler_log 67 | ) 68 | samples = self.sampler.SIR(supersample, weights, n) 69 | return samples 70 | else: 71 | raise Exception("this feature has not been implemented yet.") 72 | 73 | def MAP_estimation(self, n): 74 | """ 75 | Args: 76 | - n: int, number of seeds 77 | 78 | Returns: 79 | - X_map: torch.tensor, maximum a posteori sample 80 | """ 81 | seeds = self.sample(n) 82 | ypred = self.joint_posterior(seeds) 83 | idx_max = ypred.argmax() 84 | MAP = ypred[idx_max].item() 85 | X_map = seeds[idx_max] 86 | print("PDF of posterior at MAP: " + str(MAP)) 87 | return X_map 88 | -------------------------------------------------------------------------------- /BASQ/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from torch.distributions.multivariate_normal import MultivariateNormal 4 | 5 | 6 | class Utils: 7 | def __init__(self, device): 8 | """ 9 | Args: 10 | - device: torch.device, cpu or cuda 11 | """ 12 | self.eps = -torch.sqrt(torch.tensor(torch.finfo().max)).item() 13 | self.gpu_lim = int(5e5) 14 | self.max_iter = 10 15 | self.device = device 16 | 17 | def remove_anomalies(self, y): 18 | """ 19 | Args: 20 | - y: torch.tensor, observations 21 | 22 | Returns: 23 | - y: torch.tensor, observations whose anomalies have been removed. 24 | """ 25 | y[y.isnan()] = self.eps 26 | y[y.isinf()] = self.eps 27 | y[y < self.eps] = self.eps 28 | return y 29 | 30 | def remove_anomalies_uniform(self, X, uni_min, uni_max): 31 | """ 32 | Args: 33 | - X: torch.tensor, inputs 34 | - uni_min: torch.tensor, the minimum limit values of uniform distribution 35 | - uni_max: torch.tensor, the maximum limit values of uniform distribution 36 | 37 | Returns: 38 | - idx: bool, indices where the inputs X do not exceed the min-max limits 39 | """ 40 | logic = torch.sum(torch.stack([torch.logical_or( 41 | X[:, i] < uni_min[i], 42 | X[:, i] > uni_max[i], 43 | ) for i in range(X.size(1))]), axis=0) 44 | return (logic == 0) 45 | 46 | def is_psd(self, mat): 47 | """ 48 | Args: 49 | - mat: torch.tensor, symmetric matrix 50 | 51 | Returns: 52 | - flag: bool, flag to judge whether or not the given matrix is positive semi-definite 53 | """ 54 | return bool((mat == mat.T).all() and (torch.eig(mat)[0][:, 0] >= 0).all()) 55 | 56 | def make_cov_psd(self, cov): 57 | """ 58 | Args: 59 | - cov: torch.tensor, covariance matrix of multivariate normal distribution 60 | 61 | Returns: 62 | - cov: torch.tensor, covariance matrix of multivariate normal distribution 63 | """ 64 | if self.is_psd(cov): 65 | return cov 66 | else: 67 | warnings.warn("Estimated covariance matrix was not positive semi-definite. 
Conveting...") 68 | cov = torch.nan_to_num(cov) 69 | cov = torch.sqrt(cov * cov.T) 70 | if not self.is_psd(cov): 71 | n_dim = cov.size(0) 72 | r_increment = 2 73 | jitter = torch.ones(n_dim).to(self.device) * 1e-5 74 | n_iter = 0 75 | while not self.is_psd(cov): 76 | cov[range(n_dim), range(n_dim)] += jitter 77 | jitter *= r_increment 78 | n_iter += 1 79 | if n_iter > self.max_iter: 80 | cov = cov.diag().diag() 81 | break 82 | return cov 83 | 84 | def safe_mvn_register(self, mu, cov): 85 | """ 86 | Args: 87 | - mu: torch.tensor, mean vector of multivariate normal distribution 88 | - cov: torch.tensor, covariance matrix of multivariate normal distribution 89 | 90 | Returns: 91 | - mvn: torch.distributions, function of multivariate normal distribution 92 | """ 93 | cov = self.make_cov_psd(cov) 94 | return MultivariateNormal(mu, cov) 95 | 96 | def safe_mvn_prob(self, mu, cov, X): 97 | """ 98 | Args: 99 | - mu: torch.tensor, mean vector of multivariate normal distribution 100 | - cov: torch.tensor, covariance matrix of multivariate normal distribution 101 | - X: torch.tensor, the locations that we wish to calculate the probability density values 102 | 103 | Returns: 104 | - pdf: torch.tensor, the probability density values at given locations X. 105 | """ 106 | mvn = self.safe_mvn_register(mu, cov) 107 | if X.size(0) > self.gpu_lim: 108 | warnings.warn("The matrix size exceeds the GPU limit. Splitting.") 109 | n_split = torch.tensor(X.size(0) / self.gpu_lim).ceil().long() 110 | _X = torch.tensor_split(X, n_split) 111 | Npdfs = torch.cat( 112 | [ 113 | mvn.log_prob(_X[i]).exp() 114 | for i in range(n_split) 115 | ] 116 | ) 117 | else: 118 | Npdfs = mvn.log_prob(X).exp() 119 | return Npdfs 120 | -------------------------------------------------------------------------------- /BASQ/_basq.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from ._rchq import recombination 4 | from ._parameters import Parameters 5 | 6 | 7 | class BASQ(Parameters): 8 | def __init__(self, Xobs, Yobs, prior, true_likelihood, device): 9 | """ 10 | Goal: Estimate both evidence and posterior in one go with minimal queries. 11 | 12 | Args: 13 | - Xobs; torch.tensor, X samples, X belongs to prior measure. 14 | - Yobs; torch.tensor, Y observations, Y = true_likelihood(X). 15 | - prior; torch.distributions, prior distribution. 16 | - true_likelihood; function of y = function(x), true likelihood to be estimated. 17 | - device; torch.device, device, cpu or cuda 18 | 19 | Results: 20 | - evidence (a.k.a. marginal likelihood); 21 | EZy, VarZy = self.kq.quadrature() 22 | EZy; the mean of evidence 23 | VarZy; the variance of evidence 24 | - posterior; self.joint_posterior(x, EZy) 25 | """ 26 | super().__init__(Xobs, Yobs, prior, true_likelihood, device) 27 | 28 | def quadratures(self): 29 | """ 30 | Calculate two additional quadratures. 31 | The following quadratures are applicable only for WSABI-BQ. 32 | - Prior maximisation; the prior distribution is optimised to maximise the evidence 33 | - Uniform transformation; the prior distribution is transformed into uniform 34 | distrubution via impotance sampling. 
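        Illustrative call pattern (a sketch only, mirroring how run() invokes this method for a WSABI model; `basq` denotes an already-trained BASQ instance constructed with bq_model == "wsabi"):

            basq.gp.memorise_parameters()
            EZy_prior, VarZy_prior, EZy_uni, VarZy_uni = basq.quadratures()
            basq.gp.remind_parameters()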
35 | 36 | Args: 37 | - EZy_prior: float, the mean of the evidence when the prior is optimised to maximise the evidence 38 | - VarZy_prior: float, the variance of the evidence when the prior is optimised to maximise the evidence 39 | - EZy_uni float, the mean of the evidence when the prior is transformed into uniform distribution 40 | - VarZy_uni: float, the variance of the evidence when the prior is transformed into uniform distribution 41 | """ 42 | mvn_max = self.unimodal_approx() 43 | EZy_prior, VarZy_prior = self.kq.prior_max(mvn_max) 44 | model_IS, uni_sampler = self.uniform_trans(mvn_max) 45 | EZy_uni, VarZy_uni = self.kq.uniform_trans(model_IS, uni_sampler) 46 | return EZy_prior, VarZy_prior, EZy_uni, VarZy_uni 47 | 48 | def run_rchq(self, pts_nys, pts_rec, w_IS, kernel): 49 | """ 50 | Args: 51 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 52 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 53 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 54 | - kernel: function of covariance_matrix = function(X, Y). Positive semi-definite Gram matrix (a.k.a. kernel) 55 | 56 | Returns: 57 | - x: torch.tensor, the sparcified samples from pts_rec. The number of samples are determined by self.batch_size 58 | - w: torch.tensor, the positive weights for kernel quadrature as discretised summation. 59 | """ 60 | idx, w = recombination( 61 | pts_rec, 62 | pts_nys, 63 | self.batch_size, 64 | kernel, 65 | self.device, 66 | init_weights=w_IS, 67 | ) 68 | x = pts_rec[idx] 69 | return x, w 70 | 71 | def run_basq(self): 72 | if self.sampler_type == "uncertainty": 73 | self.sampler.update(self.gp.model) 74 | pts_nys, pts_rec, w_IS = self.sampler(self.n_rec) 75 | X, _ = self.run_rchq(pts_nys, pts_rec, w_IS, self.kernel) 76 | Y = self.true_likelihood(X) 77 | self.update(X, Y) 78 | 79 | def run(self, n_batch): 80 | """ 81 | Args: 82 | - n_batch: int, number of iteration. The total query is n_batch * self.batch_size 83 | 84 | Returns: 85 | - results: torch.tensor, [overhead, EZy, VarZy] 86 | """ 87 | results = [] 88 | overhead = 0 89 | for _ in range(n_batch): 90 | s = time.time() 91 | self.run_basq() 92 | _overhead = time.time() - s 93 | if self.show_progress: 94 | EZy, VarZy = self.kq.quadrature() 95 | results.append([_overhead, EZy, VarZy]) 96 | else: 97 | overhead += _overhead 98 | if not self.show_progress: 99 | EZy, VarZy = self.kq.quadrature() 100 | results.append([overhead, EZy, VarZy]) 101 | 102 | if self.bq_model == "wsabi": 103 | self.gp.memorise_parameters() 104 | EZy_prior, VarZy_prior, EZy_uni, VarZy_uni = self.quadratures() 105 | self.gp.remind_parameters() 106 | self.retrain() 107 | return torch.tensor(results) 108 | -------------------------------------------------------------------------------- /BASQ/_vbq.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from ._gp import update_gp, predict, predictive_covariance 4 | from ._utils import Utils 5 | 6 | 7 | class VanillaGP: 8 | def __init__( 9 | self, 10 | Xobs, 11 | Yobs, 12 | gp_kernel, 13 | device, 14 | lik=1e-10, 15 | training_iter=10000, 16 | thresh=0.01, 17 | lr=0.1, 18 | rng=10, 19 | train_lik=False, 20 | optimiser="L-BFGS-B", 21 | ): 22 | """ 23 | Args: 24 | - Xobs: torch.tensor, X samples, X belongs to prior measure. 25 | - Yobs: torch.tensor, Y observations, Y = true_likelihood(X). 
26 | - gp_kernel: gpytorch.kernels, GP kernel function 27 | - device: torch.device, device, cpu or cuda 28 | - lik: float, the initial value of GP likelihood noise variance 29 | - train_iter: int, the maximum iteration for GP hyperparameter training. 30 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 31 | - lr: float, the learning rate of Adam optimiser 32 | - rng: int, tne range coefficient of GP likelihood noise variance 33 | - train_like: bool, flag whether or not to update GP likelihood noise variance 34 | """ 35 | self.gp_kernel = gp_kernel 36 | self.device = device 37 | self.lik = lik 38 | self.training_iter = training_iter 39 | self.thresh = thresh 40 | self.lr = lr 41 | self.rng = rng 42 | self.train_lik = train_lik 43 | self.optimiser = optimiser 44 | 45 | self.jitter = 1e-6 46 | self.Y_unwarp = copy.deepcopy(Yobs) 47 | self.utils = Utils(device) 48 | 49 | self.model = update_gp( 50 | Xobs, 51 | Yobs, 52 | gp_kernel, 53 | self.device, 54 | lik=self.lik, 55 | training_iter=self.training_iter, 56 | thresh=self.thresh, 57 | lr=self.lr, 58 | rng=self.rng, 59 | train_lik=self.train_lik, 60 | optimiser=self.optimiser, 61 | ) 62 | 63 | def cat_observations(self, X, Y): 64 | """ 65 | Args: 66 | - X: torch.tensor, X samples to be added to the existing data Xobs 67 | - Y: torch.tensor, Y observations to be added to the existing data Yobs 68 | 69 | Returns: 70 | - Xall: torch.tensor, X samples that contains all samples 71 | - Yall: torch.tensor, Y observations that contains all observations 72 | """ 73 | Xobs = self.model.train_inputs[0] 74 | Yobs = self.model.train_targets 75 | if len(self.model.train_targets.shape) == 0: 76 | Yobs = Yobs.unsqueeze(0) 77 | Xall = torch.cat([Xobs, X]) 78 | Yall = torch.cat([Yobs, Y]) 79 | return Xall, Yall 80 | 81 | def update_gp(self, X, Y): 82 | """ 83 | Args: 84 | - X: torch.tensor, X samples to be added to the existing data Xobs 85 | - Y: torch.tensor, Y observations to be added to the existing data Yobs 86 | """ 87 | Xall, Yall = self.cat_observations(X, Y) 88 | self.model = update_gp( 89 | Xall, 90 | Yall, 91 | self.gp_kernel, 92 | self.device, 93 | lik=self.lik, 94 | training_iter=self.training_iter, 95 | thresh=self.thresh, 96 | lr=self.lr, 97 | rng=self.rng, 98 | train_lik=self.train_lik, 99 | optimiser=self.optimiser, 100 | ) 101 | 102 | def retrain_gp(self): 103 | Xobs = self.model.train_inputs[0] 104 | Yobs = self.model.train_targets 105 | self.model = update_gp( 106 | Xobs, 107 | Yobs, 108 | self.gp_kernel, 109 | self.device, 110 | lik=self.lik, 111 | training_iter=self.training_iter, 112 | thresh=self.thresh, 113 | lr=self.lr, 114 | rng=self.rng, 115 | train_lik=self.train_lik, 116 | optimiser=self.optimiser, 117 | ) 118 | 119 | def predictive_kernel(self, x, y): 120 | """ 121 | Args: 122 | - x: torch.tensor, x locations to be predicted 123 | - y: torch.tensor, y locations to be predicted 124 | 125 | Args: 126 | - CLy: torch.tensor, the positive semi-definite Gram matrix of predictive variance 127 | """ 128 | return predictive_covariance(x, y, self.model) 129 | 130 | def predict(self, x): 131 | """ 132 | Args: 133 | - x: torch.tensor, x locations to be predicted 134 | 135 | Returns: 136 | - mu: torch.tensor, predictive mean at given locations x. 137 | - var: torch.tensor, predictive variance at given locations x. 
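        Example (an illustrative sketch; `gp` stands for an already-constructed VanillaGP instance and `x` for a torch.Size(n_data, n_dims) tensor):

            mu, var = gp.predict(x)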
138 | """ 139 | mu, var = predict(x, self.model) 140 | return mu, var 141 | 142 | def predict_mean(self, x): 143 | """ 144 | Args: 145 | - x: torch.tensor, x locations to be predicted 146 | 147 | Returns: 148 | - mu: torch.tensor, predictive mean at given locations x. 149 | """ 150 | mu, _ = predict(x, self.model) 151 | return mu 152 | -------------------------------------------------------------------------------- /BASQ/_gaussian_calc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from ._utils import Utils 4 | from torch.distributions.uniform import Uniform 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | 7 | 8 | class GaussianCalc: 9 | def __init__(self, prior, device): 10 | """ 11 | Args: 12 | - prior; torch.distributions, prior distribution. 13 | - device; torch.device, device, cpu or cuda 14 | """ 15 | self.prior = prior 16 | self.device = device 17 | self.utils = Utils(device) 18 | 19 | def get_cache(self, model): 20 | """ 21 | Observed dataset = (Xobs, Yobs) 22 | woodbury_vector = K(Xobs, Xobs)^(-1) @ Yobs 23 | woodbury_inv = K(Xobs, Xobs)^(-1) 24 | S @ S.T = woodbury_inv 25 | 26 | Args: 27 | - model: gpytorch.models, function of GP model, typically self.wsabi.model in _basq.py 28 | 29 | Returns: 30 | - woodbury_vector: torch.tensor, Woodbury vector, K(Xobs, Xobs)^(-1) @ Yobs 31 | - woodbury_inv: torch.tensor, the inverse of Gram matrix K(Xobs, Xobs)^(-1) 32 | """ 33 | try: 34 | woodbury_vector = model.prediction_strategy.mean_cache 35 | S = model.prediction_strategy.covar_cache 36 | except AttributeError: 37 | model.eval() 38 | mean = self.prior.loc.view(-1).unsqueeze(0) 39 | model(mean) 40 | woodbury_vector = model.prediction_strategy.mean_cache 41 | S = model.prediction_strategy.covar_cache 42 | woodbury_inv = S @ S.T 43 | return woodbury_vector, woodbury_inv 44 | 45 | def parameters_extraction(self, model): 46 | self.Xobs = copy.deepcopy(model.train_inputs[0]) 47 | self.n_data, self.n_dims = self.Xobs.size() 48 | self.woodbury_vector, self.woodbury_inv = self.get_cache(model) 49 | self.outputscale = copy.deepcopy(model.covar_module.outputscale.detach()) 50 | self.lengthscale = copy.deepcopy(model.covar_module.base_kernel.lengthscale.detach()) 51 | self.W = torch.eye(self.Xobs.size(1)).to(self.device) * self.lengthscale ** 2 52 | self.v = self.outputscale * torch.sqrt(torch.linalg.det(2 * torch.pi * self.W)) 53 | 54 | def calc_expGMM(self, a): 55 | Npdfs = MultivariateNormal( 56 | a, 57 | self.W / 2, 58 | ).log_prob(self.X_ij_plus_half).exp() 59 | return torch.exp(-1 * Npdfs @ self.Wij_flat) 60 | 61 | def calc_Taylor(self, mus): 62 | return torch.stack([self.calc_expGMM(mu) for mu in mus]) 63 | 64 | def unimodal_approximation(self, model, alpha): 65 | """ 66 | approximate the GP-modelled likelihood by a unimodal multivariate normal distribution. 
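        The moment matching below uses the standard Gaussian-mixture identities (derived at the link that follows): for a mixture with weights w_m, component means m_m and a shared component covariance C, the matched mean is mu = sum_m w_m m_m and the matched covariance is cov = sum_m w_m (m_m - mu)(m_m - mu)^T + C; here the components sit at the pairwise midpoints of Xobs with C = W/2.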
67 | https://math.stackexchange.com/questions/195911/calculation-of-the-covariance-of-gaussian-mixtures 68 | 69 | Args: 70 | - model: gpytorch.models, function of GP model, typically self.wsabi.model in _basq.py 71 | - alpha: torch.tensor, the alpha hyperparameter in WSABI-BQ modelling, ell = alpha + 0.5 ell^2 72 | 73 | Returns: 74 | - mvn_pi_max: torch.distributions, mutlivariate normal distribution of optimised prior 75 | """ 76 | self.parameters_extraction(model) 77 | 78 | x = (self.Xobs.unsqueeze(1) - self.Xobs.unsqueeze(0)).reshape(self.n_data**2, self.n_dims) 79 | Npdfs = self.utils.safe_mvn_prob( 80 | torch.zeros(self.n_dims).to(self.device), 81 | 2 * self.W, 82 | x, 83 | ).reshape(self.n_data, self.n_data) 84 | 85 | _w_m = 0.5 * (self.v**2) * (self.woodbury_vector.unsqueeze(1) * self.woodbury_vector.unsqueeze(0)) * Npdfs 86 | w_m = _w_m / _w_m.sum() 87 | 88 | mu_pi_max = alpha + (w_m.unsqueeze(2) * (self.Xobs.unsqueeze(1) + self.Xobs.unsqueeze(0)) / 2).sum(axis=0).sum(axis=0) 89 | Xij2 = ((self.Xobs.unsqueeze(1) + self.Xobs.unsqueeze(0)) / 2).reshape(self.n_data ** 2, self.n_dims) - mu_pi_max 90 | W_m = w_m.reshape(self.n_data**2, 1) 91 | cov_pi_max = (W_m.unsqueeze(1) * Xij2.unsqueeze(2) @ Xij2.unsqueeze(1)).sum(axis=0) + self.W / 2 92 | mvn_pi_max = self.utils.safe_mvn_register(mu_pi_max, cov_pi_max) 93 | return mvn_pi_max 94 | 95 | def uniform_transformation(self, model, Y_unwarp): 96 | """ 97 | Args: 98 | - model: gpytorch.models, function of GP model, typically self.wsabi.model in _basq.py 99 | - Y_unwarp: torch.tensor, the raw observations without WSABI warping. 100 | 101 | Returns: 102 | - Xobs_uni: torch.tensor, the inputs transformed into uniform prior distribution 103 | - Yobs_uni: torch.tensor, the observation transformed into uniform prior distribution 104 | - uni_sampler: function of samples = function(n_samples), a uniform distribution sampler 105 | - uni_logpdf: function of logpdf = function(x), a log probability density function of uniform distribution 106 | """ 107 | self.parameters_extraction(model) 108 | uni_min = self.Xobs.min(0)[0] 109 | uni_max = self.Xobs.max(0)[0] 110 | 111 | uni_sampler = lambda N: torch.stack([ 112 | Uniform(uni_min[i], uni_max[i]).sample(torch.Size([N])) for i in range(self.n_dims) 113 | ]).T 114 | uni_logpdf = lambda X: torch.ones(X.size(0)).to(self.device) * torch.sum(-torch.log(uni_max - uni_min)) 115 | 116 | Yobs_uni = self.utils.remove_anomalies(Y_unwarp) 117 | idx = self.utils.remove_anomalies_uniform(self.Xobs, uni_min, uni_max) 118 | Xobs_uni = self.Xobs[idx] 119 | Yobs_uni = Yobs_uni[idx] 120 | return Xobs_uni, Yobs_uni, uni_sampler, uni_logpdf 121 | -------------------------------------------------------------------------------- /BASQ/_rchq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def recombination( 5 | pts_rec, # random samples for recombination 6 | pts_nys, # number of samples used for approximating kernel with Nystrom method 7 | num_pts, # number of samples finally returned 8 | kernel, # kernel 9 | device, # device 10 | init_weights=0, # initial weights of the sample for recombination 11 | ): 12 | """ 13 | Args: 14 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 15 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 16 | - num_pts: int, number of samples finally returned. In BASQ context, this is equivalent to batch size 17 | - kernel: function of covariance_matrix = function(X, Y). 
Positive semi-definite Gram matrix (a.k.a. kernel) 18 | - device: torch.device, cpu or cuda 19 | - init_weights: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 20 | 21 | Returns: 22 | - x: torch.tensor, the sparcified samples from pts_rec. The number of samples are determined by self.batch_size 23 | - w: torch.tensor, the positive weights for kernel quadrature as discretised summation. 24 | """ 25 | return rc_kernel_svd(pts_rec, pts_nys, num_pts, kernel, device, mu=init_weights) 26 | 27 | 28 | def ker_svd_sparsify(pt, s, kernel, device): 29 | _U, S, _ = torch.svd_lowrank(kernel(pt, pt), q=s) 30 | U = -1 * _U.T # Hermitian 31 | return S, U 32 | 33 | 34 | def rc_kernel_svd(samp, pt, s, kernel, device, mu=0, use_obj=True): 35 | # Nystrom method 36 | _, U = ker_svd_sparsify(pt, s - 1, kernel, device) 37 | w_star, idx_star = Mod_Tchernychova_Lyons( 38 | samp, U, pt, kernel, device, mu, use_obj=use_obj 39 | ) 40 | return idx_star, w_star 41 | 42 | 43 | def Mod_Tchernychova_Lyons(samp, U_svd, pt_nys, kernel, device, mu=0, use_obj=True, DEBUG=False): 44 | """ 45 | This function is a modified Tcherynychova_Lyons from 46 | https://github.com/FraCose/Recombination_Random_Algos/blob/master/recombination.py 47 | """ 48 | N = len(samp) 49 | n, length = U_svd.shape 50 | number_of_sets = 2 * (n + 1) 51 | 52 | # obj = torch.zeros(N).to(device) 53 | mu = torch.ones(N).to(device) / N 54 | 55 | idx_story = torch.arange(N).to(device) 56 | idx_story = idx_story[mu != 0] 57 | remaining_points = len(idx_story) 58 | 59 | while True: 60 | if remaining_points <= n + 1: 61 | idx_star = torch.arange(len(mu))[mu > 0].to(device) 62 | w_star = mu[idx_star] 63 | return w_star, idx_star 64 | 65 | elif n + 1 < remaining_points <= number_of_sets: 66 | X_mat = U_svd @ kernel(pt_nys, samp[idx_story]) 67 | w_star, idx_star, x_star, _, ERR, _, _ = Tchernychova_Lyons_CAR( 68 | X_mat.T, torch.clone(mu[idx_story]), device, DEBUG) 69 | idx_story = idx_story[idx_star] 70 | mu[:] = 0. 71 | mu[idx_story] = w_star 72 | idx_star = idx_story 73 | w_star = mu[mu > 0] 74 | return w_star, idx_star 75 | 76 | number_of_el = int(remaining_points / number_of_sets) 77 | 78 | idx = idx_story[:number_of_el * number_of_sets].reshape(number_of_el, -1) 79 | X_for_nys = torch.zeros((length, number_of_sets)).to(device) 80 | # X_for_obj = torch.zeros((1, number_of_sets)).to(device) 81 | for i in range(number_of_el): 82 | idx_tmp_i = idx_story[i * number_of_sets:(i + 1) * number_of_sets] 83 | X_for_nys += torch.multiply( 84 | kernel(pt_nys, samp[idx_tmp_i]), 85 | mu[idx_tmp_i].unsqueeze(0) 86 | ) 87 | 88 | X_tmp_tr = U_svd @ X_for_nys 89 | X_tmp = X_tmp_tr.T 90 | tot_weights = torch.sum(mu[idx], 0).to(device) 91 | idx_last_part = idx_story[number_of_el * number_of_sets:] 92 | 93 | if len(idx_last_part): 94 | X_mat = U_svd @ kernel(pt_nys, samp[idx_last_part]) 95 | X_tmp[-1] += torch.multiply( 96 | X_mat.T, 97 | mu[idx_last_part].unsqueeze(1) 98 | ).sum(axis=0) 99 | tot_weights[-1] += torch.sum(mu[idx_last_part], 0) 100 | 101 | X_tmp = torch.divide(X_tmp, tot_weights.unsqueeze(0).T) 102 | 103 | w_star, idx_star, _, _, ERR, _, _ = Tchernychova_Lyons_CAR( 104 | X_tmp, torch.clone(tot_weights), device 105 | ) 106 | 107 | idx_tomaintain = idx[:, idx_star].reshape(-1) 108 | idx_tocancel = torch.ones(idx.shape[1]).to(torch.bool).to(device) 109 | idx_tocancel[idx_star] = 0 110 | idx_tocancel = idx[:, idx_tocancel].reshape(-1) 111 | 112 | mu[idx_tocancel] = 0. 
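        # Reweighting of the surviving sets: each set retained by the CAR step keeps its new
        # recombination weight w_star, and every point inside a retained set inherits that weight
        # in proportion to its original share of the set's total weight (mu / tot_weights).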
113 | mu_tmp = torch.multiply(mu[idx[:, idx_star]], w_star) 114 | mu_tmp = torch.divide(mu_tmp, tot_weights[idx_star]) 115 | mu[idx_tomaintain] = mu_tmp.reshape(-1) 116 | 117 | idx_tmp = idx_star == number_of_sets - 1 118 | idx_tmp = torch.arange(len(idx_tmp))[idx_tmp != 0].to(device) 119 | # if idx_star contains the last barycenter, whose set could have more points 120 | if len(idx_tmp) > 0: 121 | mu_tmp = torch.multiply(mu[idx_last_part], w_star[idx_tmp]) 122 | mu_tmp = torch.divide(mu_tmp, tot_weights[idx_star[idx_tmp]]) 123 | mu[idx_last_part] = mu_tmp 124 | idx_tomaintain = torch.cat([idx_tomaintain, idx_last_part]) 125 | else: 126 | idx_tocancel = torch.cat([idx_tocancel, idx_last_part]) 127 | mu[idx_last_part] = 0. 128 | 129 | idx_story = torch.clone(idx_tomaintain) 130 | remaining_points = len(idx_story) 131 | 132 | 133 | def Tchernychova_Lyons_CAR(X, mu, device, DEBUG=False): 134 | """ 135 | This functions reduce X from N points to n+1. 136 | This is taken from https://github.com/FraCose/Recombination_Random_Algos/blob/master/recombination.py 137 | """ 138 | X = torch.cat([torch.ones(X.size(0)).unsqueeze(0).T.to(device), X], dim=1) 139 | N, n = X.shape 140 | U, Sigma, V = torch.linalg.svd(X.T) 141 | U = torch.cat([U, torch.zeros((n, N - n)).to(device)], dim=1) 142 | Sigma = torch.cat([Sigma, torch.zeros(N - n).to(device)]) 143 | Phi = V[-(N - n):, :].T 144 | cancelled = torch.tensor([], dtype=int).to(device) 145 | 146 | for _ in range(N - n): 147 | lm = len(mu) 148 | plis = Phi[:, 0] > 0 149 | alpha = torch.zeros(lm).to(device) 150 | alpha[plis] = mu[plis] / Phi[plis, 0] 151 | idx = torch.arange(lm)[plis].to(device) 152 | idx = idx[torch.argmin(alpha[plis])] 153 | 154 | if len(cancelled) == 0: 155 | cancelled = idx.unsqueeze(0) 156 | else: 157 | cancelled = torch.cat([cancelled, idx.unsqueeze(0)]) 158 | mu[:] = mu - alpha[idx] * Phi[:, 0] 159 | mu[idx] = 0. 160 | 161 | if DEBUG and (not torch.allclose(torch.sum(mu), 1.)): 162 | # print("ERROR") 163 | print("sum ", torch.sum(mu)) 164 | 165 | Phi_tmp = Phi[:, 0] 166 | Phi = Phi[:, 1:] 167 | Phi = Phi - torch.matmul( 168 | Phi[idx].unsqueeze(1), 169 | Phi_tmp.unsqueeze(1).T, 170 | ).T / Phi_tmp[idx] 171 | Phi[idx, :] = 0. 172 | 173 | w_star = mu[mu > 0] 174 | idx_star = torch.arange(N)[mu > 0].to(device) 175 | return w_star, idx_star, torch.nan, torch.nan, 0., torch.nan, torch.nan 176 | -------------------------------------------------------------------------------- /BASQ/_scale_mmlt_wsabi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from ._gp import update_gp 4 | from ._utils import Utils 5 | from ._gaussian_calc import GaussianCalc 6 | from ._mmlt_wsabi import MmltWsabiGP 7 | 8 | 9 | class ScaleMmltWsabiGP(MmltWsabiGP): 10 | def __init__( 11 | self, 12 | Xobs, 13 | Yobs, 14 | gp_kernel, 15 | device, 16 | label="wsabim", 17 | alpha_factor=1, 18 | lik=1e-10, 19 | training_iter=10000, 20 | thresh=0.01, 21 | lr=0.1, 22 | rng=10, 23 | train_lik=False, 24 | optimiser="L-BFGS-B", 25 | ): 26 | """ 27 | Scaled-MMLT-WSABI BQ modelling 28 | Scaled-MMLT-WSABI is a scaled MMLT-WSABI BQ modelling. 29 | Scaling protects GP modelling from wide dynamic range of log-likelihood. 
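    For example, with β = max(Yobs_log), each log-observation y is first mapped to f = exp(y - β), which lies in (0, 1], so log-likelihoods spanning a very wide range stay numerically representable before the square-root and log warpings summarised in the table below are applied.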
30 | ____________________________________________ _______________________________________ __________________ 31 | | f space | g space | h space | 32 | |____________________________________________|_______________________________________|__________________| 33 | | f | g = sqrt(2(f - α)) | h = log(g + 1) | 34 | | f = α + 1/2g^2 | g = exp(h) - 1 | h | 35 | | f = α + 1/2 exp(h)exp(h) | g = exp(h) - 1 | h | 36 | |____________________________________________|_______________________________________|__________________| 37 | | f = GP(μ_f, σ_f) | g = GP(μ_g, σ_g) | h = GP(μ_h, σ_h) | 38 | | μ_f = α + 1/2(μ_g^2 + σ_g) | μ_g = exp(μ_h + 1/2 σ_h) - 1 | | 39 | | σ_f = 1/2σ_g(x,y)^2 + μ_g(x)σ_g(x,y)μ_g(y) | σ_g = μ_g(x)μ_g(y)(exp(σ_h(x,y)) - 1) | | 40 | |____________________________________________|_______________________________________|__________________| 41 | 42 | where β = max(Yobs_log), α = min(exp(Yobs_log - β)). 43 | - quadrature 44 | - log mean of marginal likelihood := log E[Z|f] + β 45 | - log variance of marginal likelihood := log Var[Z|f] + 2β 46 | 47 | WsabiMmltGP class summarises the functions of training, updating the warped GP model. 48 | This also provides the prediction and kernel of WSABI GP. 49 | The modelling of WSABI-L and WSABI-M can be easily switched by changing "label". 50 | The above table is the case of WSABI-M. 51 | 52 | Args: 53 | - Xobs: torch.tensor, X samples, X belongs to prior measure. 54 | - Yobs: torch.tensor, Y observations, Y = true_likelihood(X). 55 | - gp_kernel: gpytorch.kernels, GP kernel function 56 | - device: torch.device, device, cpu or cuda 57 | - label: string, the wsabi type, ["wsabil", "wsabim"] 58 | - lik: float, the initial value of GP likelihood noise variance 59 | - train_iter: int, the maximum iteration for GP hyperparameter training. 60 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 61 | - lr: float, the learning rate of Adam optimiser 62 | - rng: int, tne range coefficient of GP likelihood noise variance 63 | - train_like: bool, flag whether or not to update GP likelihood noise variance 64 | - optimiser: string, select the optimiser ["L-BFGS-B", "Adam"] 65 | """ 66 | self.gp_kernel = gp_kernel 67 | self.device = device 68 | self.alpha_factor = 1 69 | self.alpha = alpha_factor 70 | self.lik = lik 71 | self.training_iter = training_iter 72 | self.thresh = thresh 73 | self.lr = lr 74 | self.rng = rng 75 | self.train_lik = train_lik 76 | self.optimiser = optimiser 77 | 78 | self.jitter = 0 # 1e-6 79 | self.Y_log = copy.deepcopy(Yobs) 80 | self.utils = Utils(device) 81 | 82 | self.model = update_gp( 83 | Xobs, 84 | self.process_y_warping_with_scaling(Yobs), 85 | gp_kernel, 86 | self.device, 87 | lik=self.lik, 88 | training_iter=self.training_iter, 89 | thresh=self.thresh, 90 | lr=self.lr, 91 | rng=self.rng, 92 | train_lik=self.train_lik, 93 | optimiser=self.optimiser, 94 | ) 95 | self.gauss = GaussianCalc(self.model, self.device) 96 | 97 | def process_y_warping_with_scaling(self, y_obs): 98 | """ 99 | Args: 100 | - y_obs: torch.tensor, observations of true_loglikelihood 101 | 102 | Returns: 103 | - y_h: torch.tensor, warped observations in h space that contains no anomalies and the updated alpha hyperparameter. 
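        Concretely, the chain applied below is: β = max(y), y_f = exp(y - β), α = alpha_factor * min(y_f), followed by the f → g → h warping summarised in the class docstring, i.e. y_h = log(sqrt(2 * (y_f - α)) + 1).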
104 | """ 105 | 106 | y = self.utils.remove_anomalies(y_obs) 107 | self.beta = torch.max(y) 108 | y_f = torch.exp(y - self.beta) 109 | self.alpha = self.alpha_factor * torch.min(y_f) 110 | y_h = self.warp_from_f_to_h(y_f) 111 | return y_h 112 | 113 | def cat_observations_with_scaling(self, X, Y): 114 | """ 115 | Args: 116 | - X: torch.tensor, X samples to be added to the existing data Xobs 117 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 118 | 119 | Returns: 120 | - Xall: torch.tensor, X samples that contains all samples 121 | - Yall: torch.tensor, warped Y observations that contains all observations 122 | """ 123 | Xobs = self.model.train_inputs[0] 124 | Yobs_log = copy.deepcopy(self.Y_log) 125 | if len(self.model.train_targets.shape) == 0: 126 | Yobs_log = Yobs_log.unsqueeze(0) 127 | Xall = torch.cat([Xobs, X]) 128 | Yall_log = torch.cat([Yobs_log, Y]) 129 | self.Y_log = copy.deepcopy(Yall_log) 130 | Yall_h = self.process_y_warping_with_scaling(Yall_log) 131 | return Xall, Yall_h 132 | 133 | def update_mmlt_gp_with_scaling(self, X, Y): 134 | """ 135 | Args: 136 | - X: torch.tensor, X samples to be added to the existing data Xobs 137 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 138 | """ 139 | X_h, Y_h = self.cat_observations_with_scaling(X, Y) 140 | self.model = update_gp( 141 | X_h, 142 | Y_h, 143 | self.gp_kernel, 144 | self.device, 145 | lik=self.lik, 146 | training_iter=self.training_iter, 147 | thresh=self.thresh, 148 | lr=self.lr, 149 | rng=self.rng, 150 | train_lik=self.train_lik, 151 | optimiser=self.optimiser, 152 | ) 153 | 154 | def retrain_gp_with_scaling(self): 155 | X_h = self.model.train_inputs[0] 156 | Y_h = self.process_y_warping_with_scaling(copy.deepcopy(self.Y_log)) 157 | self.model = update_gp( 158 | X_h, 159 | Y_h, 160 | self.gp_kernel, 161 | self.device, 162 | lik=self.lik, 163 | training_iter=self.training_iter, 164 | thresh=self.thresh, 165 | lr=self.lr, 166 | rng=self.rng, 167 | train_lik=self.train_lik, 168 | optimiser=self.optimiser, 169 | ) 170 | -------------------------------------------------------------------------------- /BASQ/_quadrature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._gp import predict 3 | from ._rchq import recombination 4 | 5 | 6 | class KernelQuadrature: 7 | def __init__(self, n_rec, n_nys, n_quad, batch_size, sampler, kernel, device, mean_predict): 8 | """ 9 | Args: 10 | - n_rec: int, subsampling size for kernel recombination 11 | - nys_ratio: float, subsubsampling ratio for Nystrom. 12 | - n_nys: int, number of Nystrom samples; int(nys_ratio * n_rec) 13 | - n_quad: int, number of kernel recombination subsamples; int(quad_ratio * n_rec) 14 | - batch_size: int, batch size 15 | - sampler: function of samples = function(n_samples) 16 | - kernel: function of covariance_matrix = function(X, Y). Positive semi-definite Gram matrix (a.k.a. 
kernel) 17 | - device: torch.device, cpu or cuda 18 | - mean_predict: function of mean = function(x), the function that returns the predictive mean at given x 19 | """ 20 | self.n_rec = n_rec 21 | self.n_nys = n_nys 22 | self.n_quad = n_quad 23 | self.batch_size = batch_size 24 | self.sampler = sampler 25 | self.kernel = kernel 26 | self.device = device 27 | self.mean_predict = mean_predict 28 | 29 | def rchq(self, pts_nys, pts_rec, w_IS, batch_size, kernel): 30 | """ 31 | Args: 32 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 33 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 34 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 35 | - batch_size: int, batch size 36 | - kernel: function of covariance_matrix = function(X, Y). Positive semi-definite Gram matrix (a.k.a. kernel) 37 | 38 | Returns: 39 | - x: torch.tensor, the sparcified samples from pts_rec. The number of samples are determined by self.batch_size 40 | - w: torch.tensor, the positive weights for kernel quadrature as discretised summation. 41 | """ 42 | idx, w = recombination( 43 | pts_rec, 44 | pts_nys, 45 | batch_size, 46 | kernel, 47 | self.device, 48 | init_weights=w_IS, 49 | ) 50 | x = pts_rec[idx] 51 | return x, w 52 | 53 | def quadrature(self): 54 | """ 55 | Returns: 56 | - EZy: float, the mean of the evidence 57 | - VarZy: float, the variance of the evidence 58 | """ 59 | pts_nys, pts_rec, w_IS = self.sampler(self.n_quad) 60 | X, w = self.rchq(pts_nys, pts_rec, w_IS, self.batch_size, self.kernel) 61 | self.EZy = (w @ self.mean_predict(X)).item() 62 | VarZy = (w @ self.kernel(X, X) @ w).item() 63 | print("E[Z|y]: " + str(self.EZy) + " Var[Z|y]: " + str(VarZy)) 64 | return self.EZy, VarZy 65 | 66 | def prior_max(self, mvn_max): 67 | """ 68 | Args: 69 | - mvn_max: torch.distributions, mutlivariate normal distribution of optimised prior distribution 70 | 71 | Returns: 72 | - EZy_prior: float, the mean of the evidence when the prior is optimised to maximise the evidence 73 | - VarZy_prior: float, the variance of the evidence when the prior is optimised to maximise the evidence 74 | """ 75 | pts_rec = mvn_max.sample(sample_shape=torch.Size([self.n_quad])) 76 | pts_nys = pts_rec[:self.n_nys] 77 | w_IS = torch.ones(self.n_quad) / self.n_quad 78 | 79 | X, w = self.rchq(pts_nys, pts_rec, w_IS, self.batch_size, self.kernel) 80 | EZy = (w @ self.mean_predict(X)).item() 81 | VarZy = (w @ self.kernel(X, X) @ w).item() 82 | print("prior maximisation") 83 | print("E[Z|y]: " + str(EZy) + " Var[Z|y]: " + str(VarZy)) 84 | return EZy, VarZy 85 | 86 | def uniform_trans(self, model_IS, uni_sampler): 87 | """ 88 | Args: 89 | - model_IS: gpytorch.models, function of GP model that assumes that 90 | prior is uniform distribution transformed via importance sampling 91 | - uni_sampler: function of samples = function(n_samples), uniform distribution sampler 92 | 93 | Returns: 94 | - EZy_uni float, the mean of the evidence when the prior is transformed into uniform distribution 95 | - VarZy_uni: float, the variance of the evidence when the prior is transformed into uniform distribution 96 | """ 97 | pts_rec = uni_sampler(self.n_quad) 98 | pts_nys = pts_rec[:self.n_nys] 99 | w_IS = torch.ones(self.n_quad) / self.n_quad 100 | 101 | X, w = self.rchq(pts_nys, pts_rec, w_IS, self.batch_size, model_IS.covar_module.forward) 102 | mean, _ = predict(X, model_IS) 103 | EZy = (w @ mean).item() 104 | VarZy = (w @ 
model_IS.covar_module.forward(X, X) @ w).item() 105 | print("uniform transformation") 106 | print("E[Z|y]: " + str(EZy) + " Var[Z|y]: " + str(VarZy)) 107 | return EZy, VarZy 108 | 109 | 110 | class ScaleKernelQuadrature(KernelQuadrature): 111 | def __init__(self, n_rec, n_nys, n_quad, batch_size, sampler, gp, device): 112 | """ 113 | Args: 114 | - n_rec: int, subsampling size for kernel recombination 115 | - nys_ratio: float, subsubsampling ratio for Nystrom. 116 | - n_nys: int, number of Nystrom samples; int(nys_ratio * n_rec) 117 | - n_quad: int, number of kernel recombination subsamples; int(quad_ratio * n_rec) 118 | - batch_size: int, batch size 119 | - sampler: function of samples = function(n_samples) 120 | - kernel: function of covariance_matrix = function(X, Y). Positive semi-definite Gram matrix (a.k.a. kernel) 121 | - device: torch.device, cpu or cuda 122 | - mean_predict: function of mean = function(x), the function that returns the predictive mean at given x 123 | """ 124 | self.n_rec = n_rec 125 | self.n_nys = n_nys 126 | self.n_quad = n_quad 127 | self.batch_size = batch_size 128 | self.sampler = sampler 129 | self.gp = gp 130 | self.kernel = self.gp.fspace_kernel 131 | self.device = device 132 | self.mean_predict = self.gp.fspace_mean_predict 133 | 134 | def rchq(self, pts_nys, pts_rec, w_IS, batch_size, kernel): 135 | """ 136 | Args: 137 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 138 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 139 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 140 | - batch_size: int, batch size 141 | - kernel: function of covariance_matrix = function(X, Y). Positive semi-definite Gram matrix (a.k.a. kernel) 142 | 143 | Returns: 144 | - x: torch.tensor, the sparcified samples from pts_rec. The number of samples are determined by self.batch_size 145 | - w: torch.tensor, the positive weights for kernel quadrature as discretised summation. 146 | """ 147 | idx, w = recombination( 148 | pts_rec, 149 | pts_nys, 150 | batch_size, 151 | kernel, 152 | self.device, 153 | init_weights=w_IS, 154 | ) 155 | x = pts_rec[idx] 156 | return x, w 157 | 158 | def quadrature(self): 159 | """ 160 | Returns: 161 | - EZy: float, the mean of the evidence 162 | - VarZy: float, the variance of the evidence 163 | """ 164 | pts_nys, pts_rec, w_IS = self.sampler(self.n_quad) 165 | X, w = self.rchq(pts_nys, pts_rec, w_IS, self.batch_size, self.kernel) 166 | EZy = (w @ self.mean_predict(X)) 167 | VarZy = (w @ self.kernel(X, X) @ w) 168 | self.logEZy = (EZy.log() + self.gp.beta).item() 169 | logVarZy = VarZy.abs().log().item() 170 | #logVarZy = (VarZy.abs().log() + 2 * self.gp.beta).item() 171 | print("logE[Z|y]: " + str(self.logEZy) + " logVar[Z|y]: " + str(logVarZy)) 172 | return self.logEZy, logVarZy 173 | -------------------------------------------------------------------------------- /BASQ/_gp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import gpytorch 4 | from ._lbfgs import FullBatchLBFGS 5 | from botorch.fit import fit_gpytorch_model 6 | from gpytorch.priors.torch_priors import GammaPrior 7 | 8 | 9 | class ExactGPModel(gpytorch.models.ExactGP): 10 | def __init__(self, train_x, train_y, likelihood, gp_kernel): 11 | """ 12 | Args: 13 | - train_x: torch.tensor, inputs. 
torch.Size(n_data, n_dims) 14 | - train_y: torch.tensor, observations 15 | - likelihood: gpytorch.likelihoods, GP likelihood model 16 | - gp_kernel: gpytorch.kernels, GP kernel model 17 | """ 18 | super(ExactGPModel, self).__init__(train_x, train_y, likelihood) 19 | self.mean_module = gpytorch.means.ConstantMean() 20 | self.covar_module = gp_kernel 21 | 22 | def forward(self, x): 23 | """ 24 | Args: 25 | - x: torch.tensor, inputs. torch.Size(n_data, n_dims) 26 | 27 | Returns: 28 | - torch.distributions, predictive posterior distribution at given x 29 | """ 30 | mean_x = self.mean_module(x) 31 | covar_x = self.covar_module(x) 32 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 33 | 34 | 35 | def set_gp(train_x, train_y, gp_kernel, device, lik=1e-10, rng=10, train_lik=False): 36 | """ 37 | We can select whether or not to train the likelihood noise variance. 38 | The true_likelihood query must be noiseless, so learning GP likelihood noise variance could be redundant. 39 | However, the likelihood noise variance plays an important role in the early stage, when only a limited number of samples is available. 40 | Setting an interval constraint therefore keeps the likelihood noise variance within a safe range. 41 | Otherwise, the GP could mistake meaningful multimodal peaks of true_likelihood for noise. 42 | 43 | Args: 44 | - train_x: torch.tensor, inputs. torch.Size(n_data, n_dims) 45 | - train_y: torch.tensor, observations 46 | - gp_kernel: gpytorch.kernels, GP kernel model 47 | - device: torch.device, cpu or cuda 48 | - lik: float, the initial value of GP likelihood noise variance 49 | - rng: int, the range coefficient of GP likelihood noise variance 50 | - train_lik: bool, flag whether or not to update GP likelihood noise variance 51 | 52 | Returns: 53 | - model: gpytorch.models, function of GP model. 
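        Example (an illustrative sketch; the RBF kernel is an assumption, the code below only requires a ScaleKernel-wrapped base kernel exposing a lengthscale):

            kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            model = set_gp(train_x, train_y, kernel, torch.device("cpu"), lik=1e-10, rng=10, train_lik=False)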
54 | """ 55 | likelihood = gpytorch.likelihoods.GaussianLikelihood() 56 | likelihood.noise_covar.register_constraint("raw_noise", gpytorch.constraints.Interval(lik / rng, lik * rng)) 57 | model = ExactGPModel(train_x, train_y, likelihood, gp_kernel) 58 | model.covar_module.base_kernel.lengthscale_prior = GammaPrior(3.0, 6.0) 59 | model.covar_module.outputscale_prior = GammaPrior(2.0, 0.15) 60 | hypers = { 61 | 'likelihood.noise_covar.noise': torch.tensor(lik), 62 | } 63 | 64 | model.initialize(**hypers) 65 | if not train_lik: 66 | model.likelihood.raw_noise.requires_grad = False 67 | 68 | if device.type == 'cuda': 69 | model = model.cuda() 70 | model.likelihood = model.likelihood.cuda() 71 | return model 72 | 73 | 74 | class Closure: 75 | """ 76 | Args: 77 | - mll: gpytorch.mlls.ExactMarginalLogLikelihood, marginal log likelihood 78 | - optimiser: torch.optim, L-BFGS-B optimizer from FullBatchLBFGS 79 | 80 | Returns: 81 | - loss: torch.tensor, negative log marginal likelihood of GP 82 | """ 83 | def __init__(self, mll, optimizer): 84 | self.mll = mll 85 | self.optimizer = optimizer 86 | self.train_inputs, self.train_targets = mll.model.train_inputs, mll.model.train_targets 87 | 88 | def __call__(self): 89 | self.optimizer.zero_grad() 90 | with gpytorch.settings.fast_computations(log_prob=True): 91 | output = self.mll.model(*self.train_inputs) 92 | args = [output, self.train_targets] 93 | loss = -self.mll(*args).sum() 94 | return loss 95 | 96 | 97 | def train_GP_with_BFGS(mll, training_iter, thresh): 98 | """ 99 | L-BFGS-B implementation is from https://github.com/hjmshi/PyTorch-LBFGS 100 | 101 | Args: 102 | - mll: gpytorch.mlls.ExactMarginalLogLikelihood, marginal log likelihood 103 | - training_iter: int, the maximum number of iteration of optimisation loop 104 | - thresh: float, the stopping criterion 105 | 106 | Returns: 107 | - mll: gpytorch.mlls.ExactMarginalLogLikelihood, marginal log likelihood 108 | """ 109 | # Use full-batch L-BFGS optimizer 110 | optimizer = FullBatchLBFGS(mll.model.parameters()) 111 | closure = Closure(mll, optimizer) 112 | loss = closure() 113 | loss.backward() 114 | loss_best = torch.tensor(1e10) 115 | 116 | for i in range(training_iter): 117 | # perform step and update curvature 118 | options = {'closure': closure, 'current_loss': loss, 'max_ls': 10} 119 | loss, _, lr, _, F_eval, G_eval, _, _ = optimizer.step(options) 120 | 121 | if loss.item() < loss_best: 122 | delta = torch.abs(loss_best - loss.detach()) 123 | loss_best = loss.item() 124 | if delta < thresh: 125 | break 126 | return mll 127 | 128 | 129 | def train_GP_with_Adam(mll, lr, training_iter, thresh): 130 | """ 131 | Args: 132 | - mll: gpytorch.mlls.ExactMarginalLogLikelihood, marginal log likelihood 133 | - lr: float, the learning rate 134 | - training_iter: int, the maximum number of iteration of optimisation loop 135 | - thresh: float, the stopping criterion 136 | 137 | Returns: 138 | - mll: gpytorch.mlls.ExactMarginalLogLikelihood, marginal log likelihood 139 | """ 140 | optimizer = torch.optim.Adam(mll.model.parameters(), lr=lr) 141 | train_x = mll.model.train_inputs[0] 142 | train_y = mll.model.train_targets 143 | loss_best = torch.tensor(1e10) 144 | 145 | for i in range(training_iter): 146 | optimizer.zero_grad() 147 | output = mll.model(train_x) 148 | loss = -mll(output, train_y) 149 | loss.backward() 150 | optimizer.step() 151 | if loss.item() < loss_best: 152 | delta = torch.abs(loss_best - loss.detach()) 153 | loss_best = loss.item() 154 | if delta < thresh: 155 | break 156 | return mll 
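# Illustrative end-to-end usage sketch (the kernel, device, and data below are placeholder assumptions):
#     Xobs = torch.randn(20, 2)
#     Yobs = true_likelihood(Xobs)  # placeholder: any noiseless likelihood evaluations
#     kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
#     model = update_gp(Xobs, Yobs, kernel, torch.device("cpu"), optimiser="Adam")
#     mu, var = predict(Xobs, model)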
157 | 158 | 159 | def train_GP(model, training_iter=50, thresh=0.01, lr=0.1, optimiser="L-BFGS-B"): 160 | """ 161 | Args: 162 | - model: gpytorch.models, function of GP model. 163 | - train_iter: int, the maximum iteration for GP hyperparameter training. 164 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 165 | - lr: float, the learning rate of Adam optimiser 166 | - optimiser: string, select the optimiser ["L-BFGS-B", "BoTorch", "Adam"] 167 | 168 | Returns: 169 | - model: gpytorch.models, function of GP model. 170 | """ 171 | model.train() 172 | model.likelihood.train() 173 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model) 174 | try: 175 | if optimiser == "BoTorch": 176 | mll = fit_gpytorch_model(mll) 177 | elif optimiser == "L-BFGS-B": 178 | mll = train_GP_with_BFGS(mll, training_iter, thresh) 179 | 180 | elif optimiser == "Adam": 181 | mll = train_GP_with_Adam(mll, lr, training_iter, thresh) 182 | else: 183 | raise Exception("The given optimiser is not defined") 184 | except: 185 | warnings.warn("Optimiser " + optimiser + " failed. Optimising again with Adam...") 186 | mll = train_GP_with_Adam(mll, lr, training_iter, thresh) 187 | return model 188 | 189 | 190 | def update_gp(train_x, train_y, gp_kernel, device, lik=1e-10, training_iter=50, thresh=0.01, lr=0.1, rng=10, train_lik=False, optimiser="L-BFGS-B"): 191 | """ 192 | Input: 193 | - train_x: torch.tensor, inputs. torch.Size(n_data, n_dims) 194 | - train_y: torch.tensor, observations 195 | - gp_kernel: gpytorch.kernels, GP kernel model 196 | - device: torch.device, cpu or cuda 197 | - lik: float, the initial value of GP likelihood noise variance 198 | - train_iter: int, the maximum iteration for GP hyperparameter training. 199 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 200 | - lr: float, the learning rate of Adam optimiser 201 | - rng: int, tne range coefficient of GP likelihood noise variance 202 | - train_like: bool, flag whether or not to update GP likelihood noise variance 203 | - optimiser: string, select the optimiser ["L-BFGS-B", "BoTorch", "Adam"] 204 | 205 | Output: 206 | - model: gpytorch.models, function of GP model. 207 | """ 208 | model = set_gp(train_x, train_y, gp_kernel, device, lik=lik, rng=rng, train_lik=train_lik) 209 | model = train_GP(model, training_iter=training_iter, thresh=thresh, lr=lr, optimiser=optimiser) 210 | return model 211 | 212 | 213 | def predict(test_x, model): 214 | """ 215 | Fast variance inference is made with LOVE via fast_pred_var(). 216 | For accurate variance inference, you can just comment out the part. 217 | 218 | Input: 219 | - model: gpytorch.models, function of GP model. 220 | 221 | Output: 222 | - pred.mean; torch.tensor, the predictive mean 223 | - pred.variance; torch.tensor, the predictive variance 224 | """ 225 | model.eval() 226 | model.likelihood.eval() 227 | 228 | try: 229 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 230 | pred = model.likelihood(model(test_x)) 231 | except: 232 | warnings.warn("Cholesky failed. 
Adding more jitter...") 233 | with torch.no_grad(), gpytorch.settings.cholesky_jitter(float=1e-2): 234 | pred = model.likelihood(model(test_x)) 235 | return pred.mean, pred.variance 236 | 237 | 238 | def get_cov_cache(model): 239 | """ 240 | woodbury_inv = K(Xobs, Xobs)^(-1) 241 | S @ S.T = woodbury_inv 242 | 243 | Input: 244 | - model: gpytorch.models, function of GP model, typically self.wsabi.model in _basq.py 245 | 246 | Output: 247 | - woodbury_inv: torch.tensor, the inverse of Gram matrix K(Xobs, Xobs)^(-1) 248 | - Xobs: torch.tensor, the observed inputs X 249 | - lik_var: torch.tensor, the GP likelihood noise variance 250 | """ 251 | Xobs = model.train_inputs[0] 252 | lik_var = model.likelihood.noise 253 | try: 254 | S = model.prediction_strategy.covar_cache 255 | except: 256 | model.eval() 257 | mean = Xobs[0].unsqueeze(0) 258 | model(mean) 259 | S = model.prediction_strategy.covar_cache 260 | woodbury_inv = S @ S.T 261 | return woodbury_inv, Xobs, lik_var 262 | 263 | 264 | def predictive_covariance(x, y, model): 265 | """ 266 | Input: 267 | - x: torch.tensor, inputs x 268 | - y: torch.tensor, inputs y 269 | - model: gpytorch.models, function of GP model. 270 | 271 | Output: 272 | - cov_xy: torch.tensor, predictive covariance matrix 273 | """ 274 | woodbury_inv, Xobs, lik_var = get_cov_cache(model) 275 | Kxy = model.covar_module.forward(x, y) 276 | KxX = model.covar_module.forward(x, Xobs) 277 | KXy = model.covar_module.forward(Xobs, y) 278 | cov_xy = Kxy - KxX @ woodbury_inv @ KXy 279 | 280 | d = min(len(x), len(y)) 281 | cov_xy[range(d), range(d)] = cov_xy[range(d), range(d)] + lik_var 282 | return cov_xy 283 | -------------------------------------------------------------------------------- /BASQ/_wsabi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from ._gp import update_gp, predict, predictive_covariance 4 | from ._utils import Utils 5 | from ._gaussian_calc import GaussianCalc 6 | 7 | 8 | class WsabiGP: 9 | def __init__( 10 | self, 11 | Xobs, 12 | Yobs, 13 | gp_kernel, 14 | device, 15 | label="wsabim", 16 | alpha_factor=0.8, 17 | lik=1e-10, 18 | training_iter=10000, 19 | thresh=0.01, 20 | lr=0.1, 21 | rng=10, 22 | train_lik=False, 23 | optimiser="L-BFGS-B", 24 | ): 25 | """ 26 | WSABI BQ modelling 27 | WsabiGP class summarises the functions of training, updating the warped GP model. 28 | This also provides the prediction and kernel of WSABI GP. 29 | The modelling of WSABI-L and WSABI-M can be easily switched by changing "label". 30 | 31 | Args: 32 | - Xobs: torch.tensor, X samples, X belongs to prior measure. 33 | - Yobs: torch.tensor, Y observations, Y = true_likelihood(X). 34 | - gp_kernel: gpytorch.kernels, GP kernel function 35 | - device: torch.device, device, cpu or cuda 36 | - label: string, the wsabi type, ["wsabil", "wsabim"] 37 | - lik: float, the initial value of GP likelihood noise variance 38 | - train_iter: int, the maximum iteration for GP hyperparameter training. 39 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 
40 | - lr: float, the learning rate of Adam optimiser 41 | - rng: int, tne range coefficient of GP likelihood noise variance 42 | - train_like: bool, flag whether or not to update GP likelihood noise variance 43 | - optimiser: string, select the optimiser ["L-BFGS-B", "Adam"] 44 | """ 45 | self.gp_kernel = gp_kernel 46 | self.device = device 47 | self.alpha_factor = alpha_factor 48 | self.lik = lik 49 | self.training_iter = training_iter 50 | self.thresh = thresh 51 | self.lr = lr 52 | self.rng = rng 53 | self.train_lik = train_lik 54 | self.optimiser = optimiser 55 | 56 | self.jitter = 0 # 1e-6 57 | self.Y_unwarp = copy.deepcopy(Yobs) 58 | self.utils = Utils(device) 59 | 60 | self.model = update_gp( 61 | Xobs, 62 | self.process_y_warping(Yobs), 63 | gp_kernel, 64 | self.device, 65 | lik=self.lik, 66 | training_iter=self.training_iter, 67 | thresh=self.thresh, 68 | lr=self.lr, 69 | rng=self.rng, 70 | train_lik=self.train_lik, 71 | optimiser=self.optimiser, 72 | ) 73 | self.setting(label) 74 | self.gauss = GaussianCalc(self.model, self.device) 75 | 76 | def setting(self, label): 77 | """ 78 | Args: 79 | - label: string, the wsabi type, ["wsabil", "wsabim"] 80 | """ 81 | if label == "wsabil": 82 | self.kernel = self.wsabil_kernel 83 | self.predict = self.wsabil_predict 84 | self.predict_mean = self.wsabil_mean_predict 85 | elif label == "wsabim": 86 | self.kernel = self.wsabim_kernel 87 | self.predict = self.wsabim_predict 88 | self.predict_mean = self.wsabim_mean_predict 89 | 90 | def warp_y(self, y): 91 | """ 92 | Args: 93 | - y: torch.tensor, observations 94 | 95 | Returns: 96 | - y: torch.tensor, warped observations 97 | """ 98 | return torch.sqrt(2 * (y - self.alpha)) 99 | 100 | def unwarp_y(self, y): 101 | """ 102 | Args: 103 | - y: torch.tensor, warped observations 104 | 105 | Returns: 106 | - y: torch.tensor, unwarped observations 107 | """ 108 | return self.alpha + 0.5 * (y ** 2) 109 | 110 | def process_y_warping(self, y): 111 | """ 112 | Args: 113 | - y: torch.tensor, observations 114 | 115 | Returns: 116 | - y: torch.tensor, warped observations that contains no anomalies and the updated alpha hyperparameter. 
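        Note: the warping applied here is the WSABI square-root warp
        g = sqrt(2 * (y - alpha)) with alpha = alpha_factor * min(y)
        (see warp_y / unwarp_y above), applied after removing anomalous observations.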
117 | """ 118 | y = self.utils.remove_anomalies(y) 119 | self.alpha = self.alpha_factor * torch.min(y) 120 | y = self.warp_y(y) 121 | return y 122 | 123 | def cat_observations(self, X, Y): 124 | """ 125 | Args: 126 | - X: torch.tensor, X samples to be added to the existing data Xobs 127 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 128 | 129 | Returns: 130 | - Xall: torch.tensor, X samples that contains all samples 131 | - Yall: torch.tensor, warped Y observations that contains all observations 132 | """ 133 | Xobs = self.model.train_inputs[0] 134 | Yobs = copy.deepcopy(self.Y_unwarp) 135 | if len(self.model.train_targets.shape) == 0: 136 | Yobs = Yobs.unsqueeze(0) 137 | Xall = torch.cat([Xobs, X]) 138 | _Yall = torch.cat([Yobs, Y]) 139 | self.Y_unwarp = copy.deepcopy(_Yall) 140 | Yall = self.process_y_warping(_Yall) 141 | return Xall, Yall 142 | 143 | def update_wsabi_gp(self, X, Y): 144 | """ 145 | Args: 146 | - X: torch.tensor, X samples to be added to the existing data Xobs 147 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 148 | """ 149 | X_warp, Y_warp = self.cat_observations(X, Y) 150 | self.model = update_gp( 151 | X_warp, 152 | Y_warp, 153 | self.gp_kernel, 154 | self.device, 155 | lik=self.lik, 156 | training_iter=self.training_iter, 157 | thresh=self.thresh, 158 | lr=self.lr, 159 | rng=self.rng, 160 | train_lik=self.train_lik, 161 | optimiser=self.optimiser, 162 | ) 163 | 164 | def retrain_gp(self): 165 | X_warp = self.model.train_inputs[0] 166 | Y_warp = self.process_y_warping(copy.deepcopy(self.Y_unwarp)) 167 | self.model = update_gp( 168 | X_warp, 169 | Y_warp, 170 | self.gp_kernel, 171 | self.device, 172 | lik=self.lik, 173 | training_iter=self.training_iter, 174 | thresh=self.thresh, 175 | lr=self.lr, 176 | rng=self.rng, 177 | train_lik=self.train_lik, 178 | optimiser=self.optimiser, 179 | ) 180 | 181 | def memorise_parameters(self): 182 | self.likelihood_memory = copy.deepcopy(torch.tensor(self.model.likelihood.noise.item())) 183 | self.outputsacle_memory = copy.deepcopy(torch.tensor(self.model.covar_module.outputscale.item())) 184 | self.lengthscale_memory = copy.deepcopy(torch.tensor(self.model.covar_module.base_kernel.lengthscale.item())) 185 | 186 | def remind_parameters(self): 187 | hypers = { 188 | 'likelihood.noise_covar.noise': self.likelihood_memory, 189 | 'covar_module.outputscale': self.outputsacle_memory, 190 | 'covar_module.base_kernel.lengthscale': self.lengthscale_memory, 191 | } 192 | self.model.initialize(**hypers) 193 | 194 | def predictive_kernel(self, x, y): 195 | """ 196 | Args: 197 | - x: torch.tensor, x locations to be predicted 198 | - y: torch.tensor, y locations to be predicted 199 | 200 | Args: 201 | - CLy: torch.tensor, the positive semi-definite Gram matrix of predictive variance 202 | """ 203 | return predictive_covariance(x, y, self.model) 204 | 205 | def wsabil_kernel(self, x, y): 206 | """ 207 | Args: 208 | - x: torch.tensor, x locations to be predicted 209 | - y: torch.tensor, y locations to be predicted 210 | 211 | Returns: 212 | - CLy: torch.tensor, the positive semi-definite Gram matrix of WSABI-L variance 213 | """ 214 | mu_x, _ = predict(x, self.model) 215 | mu_y, _ = predict(y, self.model) 216 | cov_xy = predictive_covariance(x, y, self.model) 217 | CLy = mu_x.unsqueeze(1) * cov_xy * mu_y.unsqueeze(0) 218 | 219 | d = min(len(x), len(y)) 220 | CLy[range(d), range(d)] = CLy[range(d), range(d)] + self.jitter 221 | return CLy 222 | 223 | def wsabim_kernel(self, x, 
y): 224 | """ 225 | Args: 226 | - x: torch.tensor, x locations to be predicted 227 | - y: torch.tensor, y locations to be predicted 228 | 229 | Returns: 230 | - CLy: torch.tensor, the positive semi-definite Gram matrix of WSABI-M variance 231 | """ 232 | mu_x, _ = predict(x, self.model) 233 | mu_y, _ = predict(y, self.model) 234 | cov_xy = predictive_covariance(x, y, self.model) 235 | CLy = mu_x.unsqueeze(1) * cov_xy * mu_y.unsqueeze(0) + 0.5 * (cov_xy ** 2) 236 | 237 | d = min(len(x), len(y)) 238 | CLy[range(d), range(d)] = CLy[range(d), range(d)] + self.jitter 239 | return CLy 240 | 241 | def wsabil_predict(self, x): 242 | """ 243 | Args: 244 | - x: torch.tensor, x locations to be predicted 245 | 246 | Returns: 247 | - mu: torch.tensor, unwarped predictive mean at given locations x. 248 | - var: torch.tensor, unwarped predictive variance at given locations x. 249 | """ 250 | mu_warp, var_warp = predict(x, self.model) 251 | mu = self.alpha + 0.5 * mu_warp**2 252 | var = mu_warp * var_warp * mu_warp 253 | return mu, var 254 | 255 | def wsabim_predict(self, x): 256 | """ 257 | Args: 258 | - x: torch.tensor, x locations to be predicted 259 | 260 | Returns: 261 | - mu: torch.tensor, unwarped predictive mean at given locations x. 262 | - var: torch.tensor, unwarped predictive variance at given locations x. 263 | """ 264 | mu_warp, var_warp = predict(x, self.model) 265 | mu = self.alpha + 0.5 * (mu_warp**2 + var_warp) 266 | var = mu_warp * var_warp * mu_warp + 0.5 * (var_warp ** 2) 267 | return mu, var 268 | 269 | def wsabil_mean_predict(self, x): 270 | """ 271 | Args: 272 | - x: torch.tensor, x locations to be predicted 273 | 274 | Returns: 275 | - mu: torch.tensor, unwarped predictive mean at given locations x. 276 | """ 277 | mu_warp, _ = predict(x, self.model) 278 | mu = self.alpha + 0.5 * mu_warp**2 279 | return mu 280 | 281 | def wsabim_mean_predict(self, x): 282 | """ 283 | Args: 284 | - x: torch.tensor, x locations to be predicted 285 | 286 | Returns: 287 | - mu: torch.tensor, unwarped predictive mean at given locations x. 288 | """ 289 | mu_warp, var_warp = predict(x, self.model) 290 | mu = self.alpha + 0.5 * (mu_warp**2 + var_warp) 291 | return mu 292 | 293 | def unimodal_approximation(self): 294 | """ 295 | Approximating WSABI-GP with unimodal multivariate normal distribution. 296 | This is equivalent to maximising posterior w.r.t prior distribution. 297 | The maximisation of posterior can be achieved when prior is fitted to likelihood. 298 | Such calculation can be done analytically. 299 | 300 | Returns: 301 | - mvn_pi_max: torch.distributions, mutlivariate normal distribution of optimised prior 302 | """ 303 | return self.gauss.unimodal_approximation(self.model, self.alpha) 304 | 305 | def uniform_transformation(self, prior): 306 | """ 307 | Estimating the evidence with uniform prior as post-process. 308 | By adopting importance sampling, we can estimate the evidence with arbitrary prior. 309 | ∫l(x)π(x) = ∫l(x)π(x)/g(x) g(x)dx = ∫l'(x)g(x)dx, 310 | where π(x) is the uniform prior, g(x) is the Gaussian proposal distribution. 
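        Concretely, the integrand is reweighted as l'(x) = l(x)π(x)/g(x); in the code
        below the observations are multiplied by the corresponding density ratio and a
        new GP (model_IS) is fitted to the reweighted observations.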
311 | 312 | Args: 313 | - prior: torch.distributions, prior distribution 314 | 315 | Returns: 316 | - model_IS: gpytorch.models, function of GP model which is transformed into uniform distribution 317 | - uni_sampler: function of samples = function(n_samples), a uniform distribution sampler 318 | """ 319 | Xobs_uni, Yobs_uni, uni_sampler, uni_logpdf = self.gauss.uniform_transformation( 320 | self.model, 321 | self.Y_unwarp, 322 | ) 323 | 324 | Y_IS = torch.exp(torch.log(Yobs_uni) + prior.log_prob(Xobs_uni) - uni_logpdf(Xobs_uni)) 325 | Y_IS = self.utils.remove_anomalies(Y_IS) 326 | model_IS = update_gp( 327 | Xobs_uni, 328 | Y_IS.detach(), 329 | self.gp_kernel, 330 | self.device, 331 | lik=self.lik, 332 | training_iter=self.training_iter, 333 | thresh=self.thresh, 334 | lr=self.lr, 335 | rng=self.rng, 336 | train_lik=self.train_lik, 337 | optimiser=self.optimiser, 338 | ) 339 | return model_IS, uni_sampler 340 | -------------------------------------------------------------------------------- /BASQ/experiment/ecm.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from torch.distributions.normal import Normal 5 | 6 | 7 | class CanonicalECMTwoRCs: 8 | def __init__(self, rt, r1_, t1, r2_, t2, sigma, omega): 9 | """ 10 | Args: 11 | - Rt: torch.tensor, the total resistance of the battery (aka R at f=0 [Hz]) 12 | - rt: torch.tensor, the normalised DC resistance 13 | - ri: torch.tensor, the normalised AC resistance of the i-th RC pair 14 | - ti: torch.tensor, the normalised time constant of the i-th RC pair 15 | - sigma: torch.tensor, experimental noise variance 16 | - omega: torch.tensor, angular frequency [rad/s] 17 | """ 18 | self.omega = omega 19 | self.noise_sig = torch.tensor(sigma) 20 | self.normalise_freq() 21 | self.set_parameters(rt, r1_, t1, r2_, t2) 22 | self.synthetic_data(self.noise_sig) 23 | 24 | def set_parameters(self, rt, r1_, t1, r2_, t2): 25 | """ 26 | Args: 27 | - rt: torch.tensor, the normalised DC resistance 28 | - ri: torch.tensor, the normalised AC resistance of the i-th RC pair 29 | - ti: torch.tensor, the normalised time constant of the i-th RC pair 30 | """ 31 | self.rt = rt 32 | self.t1 = t1 33 | self.r1 = torch.exp(-torch.exp(r1_)) 34 | self.t2 = t2 35 | self.r2 = torch.exp(-torch.exp(r2_)) 36 | self.Rt = torch.exp(self.rt) 37 | self.r0 = 1 - self.r1 - self.r2 38 | 39 | def normalise_freq(self): 40 | self.mu = torch.mean(torch.log(self.omega)) 41 | self.sigma = torch.std(torch.log(self.omega)) 42 | 43 | def unnormalise_tau(self, tau): 44 | """ 45 | Args: 46 | - tau: torch.tensor, time constant tau in log space; tau = ln(omega * t_i) 47 | Returns: 48 | - tau: torch.tensor, time constant tau in raw space; tau = omega * t_i 49 | """ 50 | return torch.exp(-(self.sigma * tau + self.mu)) 51 | 52 | def normalised_input(self, tau): 53 | """ 54 | Args: 55 | - tau: torch.tensor, time constant tau in raw space; tau = omega * t_i 56 | Returns: 57 | - tau: torch.tensor, time constant tau in log space; tau = ln(omega * t_i) 58 | """ 59 | return torch.log(self.omega) - (self.sigma * tau + self.mu) 60 | 61 | def real_part(self): 62 | """ 63 | Returns: 64 | - Z.real: torch.tensor, real part of impedance spectrum 65 | """ 66 | return self.Rt * ( 67 | self.r0 + self.r1 / 2 * (1 - torch.tanh(self.normalised_input(self.t1))) 68 | + self.r2 / 2 * (1 - torch.tanh(self.normalised_input(self.t2))) 69 | ) 70 | 71 | def imarginary_part(self): 72 | """ 73 | Returns: 74 | - Z.imaginary: 
torch.tensor, imaginary part of impedance spectrum 75 | """ 76 | return self.Rt * ( 77 | (self.r1 / 2) / torch.cosh(self.normalised_input(self.t1)) 78 | + (self.r2 / 2) / torch.cosh(self.normalised_input(self.t2)) 79 | ) 80 | 81 | def synthetic_data(self, sigma): 82 | """ 83 | Args: 84 | - sigma: torch.tensor, experimental noise variance 85 | """ 86 | R = torch.exp(-torch.exp(sigma)) 87 | self.reZ = self.real_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 88 | self.imZ = self.imarginary_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 89 | self.Rt_syn = copy.deepcopy(self.Rt) 90 | self.LL = self.loglikelihood(sigma) 91 | 92 | def loglikelihood(self, sigma): 93 | """ 94 | Args: 95 | - sigma: torch.tensor, experimental noise variance 96 | Returns: 97 | - LL: torch.tensor, log-likelihood 98 | """ 99 | R = torch.exp(-torch.exp(sigma)) 100 | err_reZ = self.reZ - self.real_part() 101 | err_imZ = self.imZ - self.imarginary_part() 102 | err = (err_reZ @ err_reZ + err_imZ @ err_imZ) 103 | LL = -torch.log(2 * torch.pi * R) * len(self.omega) - 0.5 * err / R 104 | return LL 105 | 106 | def convert_circuit_elements(self): 107 | """ 108 | Returns: 109 | - R0: torch.tensor, the unnormalised DC resistance 110 | - Ri: torch.tensor, the unnormalised AC resistance of the i-th RC pair 111 | - Ci: torch.tensor, the unnormalised capacitance of the i-th RC pair 112 | """ 113 | R0 = self.Rt * self.r0 114 | R1 = self.Rt * self.r1 115 | R2 = self.Rt * self.r2 116 | lnt1 = torch.exp(-(self.sigma * self.t1 + self.mu)) 117 | C1 = lnt1 / R1 118 | lnt2 = torch.exp(-(self.sigma * self.t2 + self.mu)) 119 | C2 = lnt2 / R2 120 | return R0, R1, C1, R2, C2 121 | 122 | def plot(self): 123 | # without noise 124 | plt.scatter(self.real_part(), self.imarginary_part()) 125 | plt.show() 126 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.real_part()) 127 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imarginary_part()) 128 | plt.show() 129 | 130 | # with noise 131 | plt.scatter(self.reZ, self.imZ) 132 | plt.show() 133 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.reZ) 134 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imZ) 135 | plt.show() 136 | 137 | def __call__(self, _theta): 138 | """ 139 | Args: 140 | - _theta: torch.tensor, circuit parameters, _theta = [rt, r1_, t1, r2_, t2, sigma] 141 | Returns: 142 | - LL: torch.tensor, log-likelihood 143 | """ 144 | theta = torch.squeeze(_theta).detach() 145 | R = torch.exp(-torch.exp(theta[-1])) 146 | self.set_parameters(theta[0], theta[1], theta[2], theta[3], theta[4]) 147 | err_reZ = self.reZ - self.real_part() 148 | err_imZ = self.imZ - self.imarginary_part() 149 | err = (err_reZ @ err_reZ + err_imZ @ err_imZ) 150 | LL = -torch.log(2 * torch.pi * R) * len(self.omega) - 0.5 * err / R 151 | return LL 152 | 153 | 154 | class CanonicalECMOneRCs: 155 | def __init__(self, rt, r1_, t1, sigma, omega): 156 | """ 157 | Rt: the total resistance of the battery (aka R at f=0 [Hz]) 158 | r0: the normalised DC resistance 159 | ri: the normalised AC resistance of the i-th RC pair 160 | ti: the normalised time constant of the i-th RC pair 161 | sigma: experimental noise variance 162 | omega: angular frequency [rad/s] 163 | """ 164 | self.omega = omega 165 | self.noise_sig = torch.tensor(sigma) 166 | self.normalise_freq() 167 | self.set_parameters(rt, r1_, t1) 168 | self.synthetic_data(self.noise_sig) 169 | 170 | def set_parameters(self, rt, r1_, t1): 171 | self.rt = rt 172 | self.t1 = t1 173 
| self.r1 = torch.exp(-torch.exp(r1_)) 174 | self.Rt = torch.exp(self.rt) 175 | self.r0 = 1 - self.r1 176 | 177 | def normalise_freq(self): 178 | self.mu = torch.mean(torch.log(self.omega)) 179 | self.sigma = torch.std(torch.log(self.omega)) 180 | 181 | def unnormalise_tau(self, tau): 182 | return torch.exp(-(self.sigma * tau + self.mu)) 183 | 184 | def normalised_input(self, tau): 185 | return torch.log(self.omega) - (self.sigma * tau + self.mu) 186 | 187 | def real_part(self): 188 | return self.Rt * ( 189 | self.r0 + self.r1 / 2 * (1 - torch.tanh(self.normalised_input(self.t1))) 190 | ) 191 | 192 | def imarginary_part(self): 193 | return self.Rt * ( 194 | (self.r1 / 2) / torch.cosh(self.normalised_input(self.t1)) 195 | ) 196 | 197 | def synthetic_data(self, sigma): 198 | R = torch.exp(-torch.exp(sigma)) 199 | self.reZ = self.real_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 200 | self.imZ = self.imarginary_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 201 | 202 | def set_true_data(self, reZ, imZ): 203 | self.reZ = reZ 204 | self.imZ = imZ 205 | 206 | def convert_circuit_elements(self): 207 | R0 = self.Rt * self.r0 208 | R1 = self.Rt * self.r1 209 | lnt1 = torch.exp(-(self.sigma * self.t1 + self.mu)) 210 | C1 = lnt1 / R1 211 | return R0, R1, C1 212 | 213 | def plot(self): 214 | # without noise 215 | plt.scatter(self.real_part(), self.imarginary_part()) 216 | plt.show() 217 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.real_part()) 218 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imarginary_part()) 219 | plt.show() 220 | 221 | # with noise 222 | plt.scatter(self.reZ, self.imZ) 223 | plt.show() 224 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.reZ) 225 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imZ) 226 | plt.show() 227 | 228 | def __call__(self, _theta): 229 | theta = torch.squeeze(_theta) 230 | R = torch.exp(-torch.exp(theta[-1])) 231 | self.set_parameters(theta[0], theta[1], theta[2]) 232 | err_reZ = self.reZ - self.real_part() 233 | err_imZ = self.imZ - self.imarginary_part() 234 | err = err_reZ @ err_reZ + err_imZ @ err_imZ 235 | LL = -torch.log(2 * torch.pi * R) * len(self.omega) - 0.5 * err / R 236 | return LL 237 | 238 | 239 | class CanonicalECMThreeRCs: 240 | def __init__(self, rt, r1_, t1, r2_, t2, r3_, t3, sigma, omega): 241 | """ 242 | Rt: the total resistance of the battery (aka R at f=0 [Hz]) 243 | r0: the normalised DC resistance 244 | ri: the normalised AC resistance of the i-th RC pair 245 | ti: the normalised time constant of the i-th RC pair 246 | sigma: experimental noise variance 247 | omega: angular frequency [rad/s] 248 | """ 249 | self.omega = omega 250 | self.noise_sig = torch.tensor(sigma) 251 | self.normalise_freq() 252 | self.set_parameters(rt, r1_, t1, r2_, t2, r3_, t3) 253 | self.synthetic_data(self.noise_sig) 254 | 255 | def set_parameters(self, rt, r1_, t1, r2_, t2, r3_, t3): 256 | self.rt = rt 257 | self.t1 = t1 258 | self.r1 = torch.exp(-torch.exp(r1_)) 259 | self.t2 = t2 260 | self.r2 = torch.exp(-torch.exp(r2_)) 261 | self.t3 = t3 262 | self.r3 = torch.exp(-torch.exp(r3_)) 263 | self.Rt = self.rt 264 | self.r0 = 1 - self.r1 - self.r2 - self.r3 265 | 266 | def normalise_freq(self): 267 | self.mu = torch.mean(torch.log(self.omega)) 268 | self.sigma = torch.std(torch.log(self.omega)) 269 | 270 | def unnormalise_tau(self, tau): 271 | return torch.exp(-(self.sigma * tau + self.mu)) 272 | 273 | def normalised_input(self, tau): 274 | 
return torch.log(self.omega) - (self.sigma * tau + self.mu) 275 | 276 | def real_part(self): 277 | return self.Rt * ( 278 | self.r0 + self.r1 / 2 * (1 - torch.tanh(self.normalised_input(self.t1))) 279 | + self.r2 / 2 * (1 - torch.tanh(self.normalised_input(self.t2))) 280 | + self.r3 / 2 * (1 - torch.tanh(self.normalised_input(self.t3))) 281 | ) 282 | 283 | def imarginary_part(self): 284 | return self.Rt * ( 285 | (self.r1 / 2) / torch.cosh(self.normalised_input(self.t1)) 286 | + (self.r2 / 2) / torch.cosh(self.normalised_input(self.t2)) 287 | + (self.r3 / 2) / torch.cosh(self.normalised_input(self.t3)) 288 | ) 289 | 290 | def synthetic_data(self, sigma): 291 | R = torch.exp(-torch.exp(sigma)) 292 | self.reZ = self.real_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 293 | self.imZ = self.imarginary_part() + Normal(0, 1).sample(torch.Size([len(self.omega)])) * torch.sqrt(R) 294 | 295 | def set_true_data(self, reZ, imZ): 296 | self.reZ = reZ 297 | self.imZ = imZ 298 | 299 | def convert_circuit_elements(self): 300 | R0 = self.Rt * self.r0 301 | R1 = self.Rt * self.r1 302 | R2 = self.Rt * self.r2 303 | R3 = self.Rt * self.r3 304 | lnt1 = torch.exp(-(self.sigma * self.t1 + self.mu)) 305 | C1 = lnt1 / R1 306 | lnt2 = torch.exp(-(self.sigma * self.t2 + self.mu)) 307 | C2 = lnt2 / R2 308 | lnt3 = torch.exp(-(self.sigma * self.t3 + self.mu)) 309 | C3 = lnt3 / R3 310 | return R0, R1, C1, R2, C2, R3, C3 311 | 312 | def plot(self): 313 | # without noise 314 | plt.scatter(self.real_part(), self.imarginary_part()) 315 | plt.show() 316 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.real_part()) 317 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imarginary_part()) 318 | plt.show() 319 | 320 | # with noise 321 | plt.scatter(self.reZ, self.imZ) 322 | plt.show() 323 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.reZ) 324 | plt.scatter(torch.log10(self.omega / (2 * torch.pi)), self.imZ) 325 | plt.show() 326 | 327 | def __call__(self, _theta): 328 | theta = torch.squeeze(_theta) 329 | R = torch.exp(-torch.exp(theta[-1])) 330 | self.set_parameters(theta[0], theta[1], theta[2], theta[3], theta[4], theta[5], theta[6]) 331 | err_reZ = self.reZ - self.real_part() 332 | err_imZ = self.imZ - self.imarginary_part() 333 | err = err_reZ @ err_reZ + err_imZ @ err_imZ 334 | LL = -torch.log(2 * torch.pi * R) * len(self.omega) - 0.5 * err / R 335 | return LL 336 | -------------------------------------------------------------------------------- /BASQ/_mmlt_wsabi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from ._gp import update_gp, predict, predictive_covariance 4 | from ._utils import Utils 5 | from ._gaussian_calc import GaussianCalc 6 | 7 | 8 | class MmltWsabiGP: 9 | def __init__( 10 | self, 11 | Xobs, 12 | Yobs, 13 | gp_kernel, 14 | device, 15 | label="wsabim", 16 | alpha_factor=1, 17 | lik=1e-10, 18 | training_iter=10000, 19 | thresh=0.01, 20 | lr=0.1, 21 | rng=10, 22 | train_lik=False, 23 | optimiser="L-BFGS-B", 24 | ): 25 | """ 26 | MMLT-WSABI BQ modelling 27 | MMLT-WSABI is a doubly warped GP modelling. It consists of three levels of warped space. 28 | The observation y is assumed to belong to g space. (query returns true log likelihoods.) 29 | WSABI modelling permits a factorisation trick in log-space. 
30 | ____________________________________________ _______________________________________ __________________ 31 | | f space | g space | h space | 32 | |____________________________________________|_______________________________________|__________________| 33 | | f | g = sqrt(2(f - α)) | h = log(g + 1) | 34 | | f = α + 1/2g^2 | g = exp(h) - 1 | h | 35 | |____________________________________________|_______________________________________|__________________| 36 | | f = GP(μ_f, σ_f) | g = GP(μ_g, σ_g) | h = GP(μ_h, σ_h) | 37 | | μ_f = α + 1/2(μ_g^2 + σ_g) | μ_g = exp(μ_h + 1/2 σ_h) - 1 | | 38 | | σ_f = 1/2σ_g(x,y)^2 + μ_g(x)σ_g(x,y)μ_g(y) | σ_g = μ_g(x)μ_g(y)(exp(σ_h(x,y)) - 1) | | 39 | |____________________________________________|_______________________________________|__________________| 40 | 41 | WsabiMmltGP class summarises the functions of training, updating the warped GP model. 42 | This also provides the prediction and kernel of WSABI GP. 43 | The modelling of WSABI-L and WSABI-M can be easily switched by changing "label". 44 | The above table is the case of WSABI-M. 45 | 46 | Args: 47 | - Xobs: torch.tensor, X samples, X belongs to prior measure. 48 | - Yobs: torch.tensor, Y observations, Y = true_likelihood(X). 49 | - gp_kernel: gpytorch.kernels, GP kernel function 50 | - device: torch.device, device, cpu or cuda 51 | - label: string, the wsabi type, ["wsabil", "wsabim"] 52 | - lik: float, the initial value of GP likelihood noise variance 53 | - train_iter: int, the maximum iteration for GP hyperparameter training. 54 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 55 | - lr: float, the learning rate of Adam optimiser 56 | - rng: int, tne range coefficient of GP likelihood noise variance 57 | - train_like: bool, flag whether or not to update GP likelihood noise variance 58 | - optimiser: string, select the optimiser ["L-BFGS-B", "Adam"] 59 | """ 60 | self.gp_kernel = gp_kernel 61 | self.device = device 62 | self.alpha_factor = 1 63 | self.alpha = alpha_factor 64 | self.lik = lik 65 | self.training_iter = training_iter 66 | self.thresh = thresh 67 | self.lr = lr 68 | self.rng = rng 69 | self.train_lik = train_lik 70 | self.optimiser = optimiser 71 | 72 | self.jitter = 0 # 1e-6 73 | self.Y_log = copy.deepcopy(Yobs) 74 | self.utils = Utils(device) 75 | 76 | self.model = update_gp( 77 | Xobs, 78 | self.process_y_warping(Yobs), 79 | gp_kernel, 80 | self.device, 81 | lik=self.lik, 82 | training_iter=self.training_iter, 83 | thresh=self.thresh, 84 | lr=self.lr, 85 | rng=self.rng, 86 | train_lik=self.train_lik, 87 | optimiser=self.optimiser, 88 | ) 89 | self.gauss = GaussianCalc(self.model, self.device) 90 | 91 | def warp_from_f_to_g(self, y_f): 92 | """ 93 | Args: 94 | - y_f: torch.tensor, observations in f space 95 | 96 | Returns: 97 | - y_g: torch.tensor, warped observations in g space 98 | """ 99 | y_g = torch.sqrt(2 * (y_f - self.alpha)) 100 | return y_g 101 | 102 | def warp_from_g_to_h(self, y_g): 103 | """ 104 | Args: 105 | - y_f: torch.tensor, warped observations in g space 106 | 107 | Returns: 108 | - y_g: torch.tensor, warped observations in h space 109 | """ 110 | y_h = torch.log(y_g + 1) 111 | return y_h 112 | 113 | def warp_from_f_to_h(self, y_f): 114 | """ 115 | Args: 116 | - y_f: torch.tensor, warped observations in g space 117 | 118 | Returns: 119 | - y_h: torch.tensor, warped observations in h space 120 | """ 121 | y_g = self.warp_from_f_to_g(y_f) 122 | y_h = self.warp_from_g_to_h(y_g) 123 | return y_h 124 | 125 | def 
unwarp_from_h_to_g(self, y_h): 126 | """ 127 | Args: 128 | - y_f: torch.tensor, warped observations in h space 129 | 130 | Returns: 131 | - y_g: torch.tensor, warped observations in g space 132 | """ 133 | y_g = torch.exp(y_h) - 1 134 | return y_g 135 | 136 | def unwarp_from_g_to_f(self, y_g): 137 | """ 138 | Args: 139 | - y_f: torch.tensor, warped observations in g space 140 | 141 | Returns: 142 | - y_g: torch.tensor, observations in f space 143 | """ 144 | y_f = self.alpha + 0.5 * (y_g**2) 145 | return y_f 146 | 147 | def process_y_warping(self, y_obs): 148 | """ 149 | Args: 150 | - y_obs: torch.tensor, observations of true_loglikelihood 151 | 152 | Returns: 153 | - y_h: torch.tensor, warped observations in h space that contains no anomalies and the updated alpha hyperparameter. 154 | """ 155 | 156 | y = self.utils.remove_anomalies(y_obs) 157 | y_f = torch.exp(y) 158 | self.alpha = self.alpha_factor * torch.min(y_f) 159 | y_h = self.warp_from_f_to_h(y_f) 160 | return y_h 161 | 162 | def cat_observations(self, X, Y): 163 | """ 164 | Args: 165 | - X: torch.tensor, X samples to be added to the existing data Xobs 166 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 167 | 168 | Returns: 169 | - Xall: torch.tensor, X samples that contains all samples 170 | - Yall: torch.tensor, warped Y observations that contains all observations 171 | """ 172 | Xobs = self.model.train_inputs[0] 173 | Yobs_log = copy.deepcopy(self.Y_log) 174 | if len(self.model.train_targets.shape) == 0: 175 | Yobs_log = Yobs_log.unsqueeze(0) 176 | Xall = torch.cat([Xobs, X]) 177 | Yall_log = torch.cat([Yobs_log, Y]) 178 | self.Y_log = copy.deepcopy(Yall_log) 179 | Yall_h = self.process_y_warping(Yall_log) 180 | return Xall, Yall_h 181 | 182 | def update_mmlt_gp(self, X, Y): 183 | """ 184 | Args: 185 | - X: torch.tensor, X samples to be added to the existing data Xobs 186 | - Y: torch.tensor, unwarped Y observations to be added to the existing data Yobs 187 | """ 188 | X_h, Y_h = self.cat_observations(X, Y) 189 | self.model = update_gp( 190 | X_h, 191 | Y_h, 192 | self.gp_kernel, 193 | self.device, 194 | lik=self.lik, 195 | training_iter=self.training_iter, 196 | thresh=self.thresh, 197 | lr=self.lr, 198 | rng=self.rng, 199 | train_lik=self.train_lik, 200 | optimiser=self.optimiser, 201 | ) 202 | 203 | def retrain_gp(self): 204 | X_h = self.model.train_inputs[0] 205 | Y_h = self.process_y_warping(copy.deepcopy(self.Y_log)) 206 | self.model = update_gp( 207 | X_h, 208 | Y_h, 209 | self.gp_kernel, 210 | self.device, 211 | lik=self.lik, 212 | training_iter=self.training_iter, 213 | thresh=self.thresh, 214 | lr=self.lr, 215 | rng=self.rng, 216 | train_lik=self.train_lik, 217 | optimiser=self.optimiser, 218 | ) 219 | 220 | def memorise_parameters(self): 221 | self.likelihood_memory = copy.deepcopy(torch.tensor(self.model.likelihood.noise.item())) 222 | self.outputsacle_memory = copy.deepcopy(torch.tensor(self.model.covar_module.outputscale.item())) 223 | self.lengthscale_memory = copy.deepcopy(torch.tensor(self.model.covar_module.base_kernel.lengthscale.item())) 224 | 225 | def remind_parameters(self): 226 | hypers = { 227 | 'likelihood.noise_covar.noise': self.likelihood_memory, 228 | 'covar_module.outputscale': self.outputsacle_memory, 229 | 'covar_module.base_kernel.lengthscale': self.lengthscale_memory, 230 | } 231 | self.model.initialize(**hypers) 232 | 233 | def hspace_predict(self, x): 234 | """ 235 | Args: 236 | - x: torch.tensor, x locations to be predicted 237 | 238 | Returns: 239 | - 
mu_h: torch.tensor, unwarped predictive mean in h space at given locations x. 240 | - var_h: torch.tensor, unwarped predictive variance in h space at given locations x. 241 | """ 242 | mu_h, var_h = predict(x, self.model) 243 | return mu_h, var_h 244 | 245 | def gspace_predict(self, x): 246 | """ 247 | Args: 248 | - x: torch.tensor, x locations to be predicted 249 | 250 | Returns: 251 | - mu_g: torch.tensor, unwarped predictive mean in g space at given locations x. 252 | - var_g: torch.tensor, unwarped predictive variance in g space at given locations x. 253 | """ 254 | mu_h, var_h = self.hspace_predict(x) 255 | mu_g = (mu_h + 0.5 * var_h).exp() - 1 256 | var_g = (mu_g ** 2) * (var_h.exp() - 1) 257 | return mu_g, var_g 258 | 259 | def fspace_predict(self, x): 260 | """ 261 | Args: 262 | - x: torch.tensor, x locations to be predicted 263 | 264 | Returns: 265 | - mu_f: torch.tensor, unwarped predictive mean in f space at given locations x. 266 | - var_f: torch.tensor, unwarped predictive variance in f space at given locations x. 267 | """ 268 | mu_g, var_g = self.gspace_predict(x) 269 | mu_f = self.alpha + 0.5 * (mu_g**2 + var_g) 270 | var_f = var_g * mu_g * var_g + 0.5 * (var_g ** 2) 271 | return mu_f, var_f 272 | 273 | def hspace_mean_predict(self, x): 274 | """ 275 | Args: 276 | - x: torch.tensor, x locations to be predicted 277 | 278 | Returns: 279 | - mu_h: torch.tensor, unwarped predictive mean in h space at given locations x. 280 | """ 281 | mu_h, _ = self.hspace_predict(x) 282 | return mu_h 283 | 284 | def gspace_mean_predict(self, x): 285 | """ 286 | Args: 287 | - x: torch.tensor, x locations to be predicted 288 | 289 | Returns: 290 | - mu_g: torch.tensor, unwarped predictive mean in g space at given locations x. 291 | """ 292 | mu_g, _ = self.gspace_predict(x) 293 | return mu_g 294 | 295 | def fspace_mean_predict(self, x): 296 | """ 297 | Args: 298 | - x: torch.tensor, x locations to be predicted 299 | 300 | Returns: 301 | - mu_f: torch.tensor, unwarped predictive mean in f space at given locations x. 
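          Computed as mu_f = alpha + 0.5 * (mu_g**2 + var_g), i.e. the WSABI-M
          unwarping of the g-space posterior (see fspace_predict above).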
302 | """ 303 | mu_f, _ = self.fspace_predict(x) 304 | return mu_f 305 | 306 | def hspace_kernel(self, x, y): 307 | """ 308 | Args: 309 | - x: torch.tensor, x locations to be predicted 310 | - y: torch.tensor, y locations to be predicted 311 | 312 | Args: 313 | - CLy: torch.tensor, the positive semi-definite Gram matrix of predictive variance in hscape 314 | """ 315 | return predictive_covariance(x, y, self.model) 316 | 317 | def gspace_kernel(self, x, y): 318 | """ 319 | Args: 320 | - x: torch.tensor, x locations to be predicted 321 | - y: torch.tensor, y locations to be predicted 322 | 323 | Returns: 324 | - CLy: torch.tensor, the positive semi-definite Gram matrix of predictive variance in gscape 325 | """ 326 | mu_g_x = self.gspace_mean_predict(x) 327 | mu_g_y = self.gspace_mean_predict(y) 328 | cov_h_xy = self.hspace_kernel(x, y) 329 | CLy = mu_g_x.unsqueeze(1) * mu_g_y.unsqueeze(0) * (cov_h_xy.exp() - 1) 330 | 331 | d = min(len(x), len(y)) 332 | CLy[range(d), range(d)] = CLy[range(d), range(d)] + self.jitter 333 | return CLy 334 | 335 | def fspace_kernel(self, x, y): 336 | """ 337 | Args: 338 | - x: torch.tensor, x locations to be predicted 339 | - y: torch.tensor, y locations to be predicted 340 | 341 | Returns: 342 | - CLy: torch.tensor, the positive semi-definite Gram matrix of predictive variance in gscape 343 | """ 344 | mu_g_x = self.gspace_mean_predict(x) 345 | mu_g_y = self.gspace_mean_predict(y) 346 | cov_g_xy = self.gspace_kernel(x, y) 347 | CLy = mu_g_x.unsqueeze(1) * cov_g_xy * mu_g_y.unsqueeze(0) + 0.5 * (cov_g_xy ** 2) 348 | 349 | d = min(len(x), len(y)) 350 | CLy[range(d), range(d)] = CLy[range(d), range(d)] + self.jitter 351 | return CLy 352 | -------------------------------------------------------------------------------- /BASQ/_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gpytorch 3 | from ._quadrature import KernelQuadrature, ScaleKernelQuadrature 4 | from ._sampler import PriorSampler, UncertaintySampler, LogUncertaintySampler 5 | from ._wsabi import WsabiGP 6 | from ._vbq import VanillaGP 7 | from ._scale_mmlt_wsabi import ScaleMmltWsabiGP 8 | from ._posterior import Posterior 9 | 10 | 11 | def parameter_set(): 12 | # BQ Modelling 13 | params = { 14 | "bq_model": "mmlt", # select a BQ model from ["mmlt", "wsabi", "vbq"], vbq stands for Vanilla BQ 15 | "sampler_type": "uncertainty", # select a sampler from ["uncertainty", "prior"] 16 | "kernel_type": "RBF", # select a kernel from ["RBF", "Matern32", "Matern52"] 17 | 18 | # WSABI modelling 19 | "wsabi_type": "wsabim", # select a wsabi type from ["wsabil", "wsabim"] 20 | "alpha_factor": 1, # coefficient of alpha in WSABI modelling; alpha = 0.8 * min(y) 21 | 22 | # GP hyperparameter training with type-II MLE 23 | "optimiser": "BoTorch", # select a optimiser from ["L-BFGS-B", "BoTorch", "Adam"], BoTorch is the slowest but the most accurate 24 | "lik": 1e-10, # centre value of GP likelihood noise. 25 | "rng": 10, # range of likelihood noise [lik/rng, lik*rng] 26 | "train_lik": False, # flag whether or not to train likelihood noise. if False, the noise is fixed with lik 27 | "training_iter": 1000, # maximum number of SDG interations 28 | "thresh": 0.05, # stopping criterion. threshold = last_MLL - current_MLL 29 | "lr": 0.1, # learning rate of Adam 30 | 31 | # RCHQ hyperparameters 32 | "n_rec": 20000, # subsampling size for kernel recombination 33 | "nys_ratio": 1e-2, # subsubsampling ratio for Nystrom. 
Number of Nystrom samples is nys_ratio * n_rec 34 | "batch_size": 100, # batch size 35 | "quad_ratio": 5, # supersampling ratio for quadrature. Number of recombination samples is quad_ratio * n_rec 36 | 37 | # Uncertainty sampling 38 | "ratio": 1, # mixing ratio of prior and uncertainty sampling, 0 < r < 1 39 | "n_gaussians": 50, # number of Gaussians approximating the GP-modelled acquisition function 40 | "threshold": 0, # threshold to cut off the insignificant Gaussians 41 | "sampling_method": "exact", # select sampling method from ["exact", "approx"] 42 | 43 | # Utility 44 | "show_progress": True, # flag whether or not show the quadrature result over each iteration. 45 | } 46 | return params 47 | 48 | class Parameters: 49 | def __init__(self, Xobs, Yobs, prior, true_likelihood, device): 50 | """ 51 | Args: 52 | - Xobs; torch.tensor, X samples, X belongs to prior measure. 53 | - Yobs; torch.tensor, Y observations, Y = true_likelihood(X). 54 | - prior; torch.distributions, prior distribution. 55 | - true_likelihood; function of y = function(x), true likelihood to be estimated. 56 | - device; torch.device, device, cpu or cuda 57 | """ 58 | params = parameter_set() 59 | self.load_setting(params, Xobs, Yobs, prior, true_likelihood, device) 60 | 61 | def load_setting(self, params, Xobs, Yobs, prior, true_likelihood, device): 62 | """ 63 | Args: 64 | - params; dict, the dictionary of the parameters 65 | - Xobs; torch.tensor, X samples, X belongs to prior measure. 66 | - Yobs; torch.tensor, Y observations, Y = true_likelihood(X). 67 | - prior; torch.distributions, prior distribution. 68 | - true_likelihood; function of y = function(x), true likelihood to be estimated. 69 | - device; torch.device, device, cpu or cuda 70 | """ 71 | self.n_rec = params["n_rec"] 72 | self.batch_size = params["batch_size"] 73 | self.device = device 74 | self.show_progress = params["show_progress"] 75 | 76 | self.prior = prior 77 | self.true_likelihood = true_likelihood 78 | 79 | self.bq_model = params["bq_model"] 80 | self.sampler_type = params["sampler_type"] 81 | self.check_compatibility(self.bq_model, self.sampler_type, params["kernel_type"]) 82 | gp_kernel = self.set_kernel(params["kernel_type"]) 83 | self.set_model( 84 | self.bq_model, 85 | Xobs, 86 | Yobs, 87 | gp_kernel, 88 | params["wsabi_type"], 89 | params["alpha_factor"], 90 | params["lik"], 91 | params["training_iter"], 92 | params["thresh"], 93 | params["lr"], 94 | params["rng"], 95 | params["train_lik"], 96 | params["optimiser"], 97 | ) 98 | self.set_sampler( 99 | self.sampler_type, 100 | params["sampling_method"], 101 | self.bq_model, 102 | prior, 103 | self.n_rec, 104 | params["nys_ratio"], 105 | params["ratio"], 106 | params["n_gaussians"], 107 | params["threshold"], 108 | ) 109 | self.set_quadrature( 110 | params["nys_ratio"], 111 | int(self.n_rec * params["nys_ratio"]), 112 | int(self.n_rec * params["quad_ratio"]), 113 | self.bq_model, 114 | ) 115 | self.set_posterior() 116 | self.verbose( 117 | self.bq_model, 118 | self.sampler_type, 119 | params["kernel_type"], 120 | params["sampling_method"], 121 | params["optimiser"], 122 | ) 123 | 124 | def verbose(self, bq_model, sampler_type, kernel_type, sampling_method, optimiser): 125 | print( 126 | "BQ model: " + bq_model 127 | + " | kernel: " + kernel_type 128 | + " | sampler: " + sampler_type 129 | + " | sampling_method: " + sampling_method 130 | + " | optimiser: " + optimiser 131 | ) 132 | 133 | def set_sampler(self, sampler_type, sampling_method, bq_model, prior, n_rec, nys_ratio, ratio, 
n_gaussians, threshold): 134 | """ 135 | Args: 136 | - sampler_type; string, ["uncertainty", "prior"] 137 | - sampling_method; string, ["exact", "approx"] 138 | - bq_model: string, ["scale_mmlt", "mmlt", "wsabi", "vbq"] 139 | - prior: torch.distributions, prior distribution. 140 | - n_rec: int, subsampling size for kernel recombination 141 | - nys_ratio: float, subsubsampling ratio for Nystrom. Number of Nystrom samples is nys_ratio * n_rec 142 | - ratio: float, mixing ratio of prior and uncertainty sampling, 0 < r < 1 143 | - n_gaussians: int, number of Gaussians approximating the GP-modelled acquisition function 144 | - threshold: float, threshold to cut off the insignificant Gaussians 145 | """ 146 | self.prior_sampler = PriorSampler(prior, n_rec, nys_ratio, self.device) 147 | 148 | if sampler_type == "uncertainty": 149 | if bq_model == "wsabi": 150 | self.sampler = UncertaintySampler( 151 | prior, 152 | self.gp.model, 153 | n_rec, 154 | nys_ratio, 155 | self.device, 156 | sampling_method=sampling_method, 157 | ratio=ratio, 158 | n_gaussians=n_gaussians, 159 | threshold=threshold, 160 | ) 161 | self.kernel = self.gp.predictive_kernel 162 | self.kernel_quadrature = self.gp.kernel 163 | elif bq_model == "mmlt": 164 | self.sampler = LogUncertaintySampler( 165 | prior, 166 | self.gp, 167 | n_rec, 168 | nys_ratio, 169 | self.device, 170 | sampling_method=sampling_method, 171 | ratio=ratio, 172 | n_gaussians=n_gaussians, 173 | threshold=threshold, 174 | ) 175 | self.kernel = self.gp.gspace_kernel 176 | self.kernel_quadrature = self.gp.fspace_kernel 177 | else: 178 | raise Exception("The given bq_model is not compatible with the uncertainty sampler.") 179 | elif sampler_type == "prior": 180 | self.sampler = self.prior_sampler 181 | self.kernel_quadrature = self.kernel 182 | else: 183 | raise Exception("The given sampler_type is undefined.") 184 | 185 | def set_model(self, bq_model, Xobs, Yobs, gp_kernel, wsabi_type, alpha_factor, lik, training_iter, thresh, lr, rng, train_lik, optimiser): 186 | """ 187 | Args: 188 | - bq_model: string, ["mmlt", "wsabi", "vbq"] 189 | - Xobs: torch.tensor, X samples, X belongs to prior measure. 190 | - Yobs: torch.tensor, Y observations, Y = true_likelihood(X). 191 | - gp_kernel: gpytorch.kernels, GP kernel function 192 | - wsabi_type: string, ["wsabil", "wsabim"] 193 | - alpha_factor: float, coefficient of alpha in WSABI modelling; alpha = 0.8 * min(y) 194 | - lik: float, the initial value of GP likelihood noise variance 195 | - train_iter: int, the maximum iteration for GP hyperparameter training. 196 | - thresh: float, the threshold as a stopping criterion of GP hyperparameter training. 
197 | - lr: float, the learning rate of Adam optimiser 198 | - rng: int, tne range coefficient of GP likelihood noise variance 199 | - train_like: bool, flag whether or not to update GP likelihood noise variance 200 | - optimiser: string, select the optimiser ["L-BFGS-B", "BoTorch", "Adam"] 201 | """ 202 | if bq_model == "wsabi": 203 | self.gp = WsabiGP( 204 | Xobs, 205 | Yobs, 206 | gp_kernel, 207 | self.device, 208 | label=wsabi_type, 209 | alpha_factor=alpha_factor, 210 | lik=lik, 211 | training_iter=training_iter, 212 | thresh=thresh, 213 | lr=lr, 214 | rng=rng, 215 | train_lik=train_lik, 216 | optimiser=optimiser, 217 | ) 218 | self.kernel = self.gp.kernel 219 | self.predict_mean = self.gp.predict_mean 220 | self.predict = self.gp.predict 221 | self.retrain = self.gp.retrain_gp 222 | self.update = self.gp.update_wsabi_gp 223 | self.unimodal_approx = self.gp.unimodal_approximation 224 | self.uniform_trans = self.gp.uniform_transformation 225 | 226 | elif bq_model == "mmlt": 227 | self.gp = ScaleMmltWsabiGP( 228 | Xobs, 229 | Yobs, 230 | gp_kernel, 231 | self.device, 232 | label="wsabim", 233 | alpha_factor=1, 234 | lik=lik, 235 | training_iter=training_iter, 236 | thresh=thresh, 237 | lr=lr, 238 | rng=rng, 239 | train_lik=train_lik, 240 | optimiser=optimiser, 241 | ) 242 | self.predict_mean = self.gp.fspace_mean_predict 243 | self.predict = self.gp.fspace_predict 244 | self.retrain = self.gp.retrain_gp_with_scaling 245 | self.kernel = self.gp.fspace_kernel 246 | self.update = self.gp.update_mmlt_gp_with_scaling 247 | 248 | elif bq_model == "vbq": 249 | self.gp = VanillaGP( 250 | Xobs, 251 | Yobs, 252 | gp_kernel, 253 | self.device, 254 | lik=lik, 255 | training_iter=training_iter, 256 | thresh=thresh, 257 | lr=lr, 258 | rng=rng, 259 | train_lik=train_lik, 260 | optimiser=optimiser, 261 | ) 262 | self.kernel = self.gp.predictive_kernel 263 | self.update = self.gp.update_gp 264 | self.retrain = self.gp.retrain_gp 265 | self.predict_mean = self.gp.predict_mean 266 | self.predict = self.gp.predict 267 | self.retrain = self.gp.retrain_gp 268 | else: 269 | raise Exception("The given bq_model is undefined.") 270 | 271 | def set_quadrature(self, nys_ratio, n_nys, n_quad, bq_model): 272 | """ 273 | Args: 274 | - nys_ratio: float, subsubsampling ratio for Nystrom. 
275 | - n_nys: int, number of Nystrom samples; int(nys_ratio * n_rec) 276 | - n_quad: int, number of kernel recombination subsamples; int(quad_ratio * n_rec) 277 | - bq_model: string, ["mmlt", "wsabi", "vbq"] 278 | """ 279 | if bq_model == "mmlt": 280 | self.kq = ScaleKernelQuadrature( 281 | self.n_rec, 282 | n_nys, 283 | n_quad, 284 | self.batch_size, 285 | self.prior_sampler, 286 | self.gp, 287 | self.device, 288 | ) 289 | else: 290 | self.kq = KernelQuadrature( 291 | self.n_rec, 292 | n_nys, 293 | n_quad, 294 | self.batch_size, 295 | self.prior_sampler, 296 | self.kernel_quadrature, 297 | self.device, 298 | self.predict_mean, 299 | ) 300 | 301 | def set_posterior(self): 302 | self.posterior = Posterior( 303 | self.bq_model, 304 | self.prior, 305 | self.gp, 306 | self.predict_mean, 307 | self.kq, 308 | self.sampler, 309 | self.sampler_type, 310 | ) 311 | 312 | def set_kernel(self, kernel_type): 313 | """ 314 | Args: 315 | - kernel_type: string, ["RBF", "Matern32", "Matern52"] 316 | 317 | Returns: 318 | - gp_kernel: gpytorch.kernels, function of GP kernel 319 | """ 320 | if kernel_type == "RBF": 321 | gp_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) 322 | elif kernel_type == "Matern32": 323 | gp_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=1.5)) 324 | elif kernel_type == "Matern52": 325 | gp_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=2.5)) 326 | else: 327 | raise Exception("The given kernel_type is undefined.") 328 | return gp_kernel 329 | 330 | def check_compatibility(self, bq_model, sampler_type, kernel_type): 331 | """ 332 | Args: 333 | - bq_model: string, ["mmlt", "wsabi", "vbq"] 334 | - sampler_type; string, ["uncertainty", "prior"] 335 | - kernel_type: string, ["RBF", "Matern32", "Matern52"] 336 | """ 337 | if bq_model == "wsabi": 338 | if not kernel_type == "RBF": 339 | raise AssertionError("WSABI model requires RBF kernel.") 340 | else: 341 | if not sampler_type == "prior" and not bq_model == "mmlt": 342 | raise AssertionError("Uncertainty sampling requires WSABI-L modelling with RBF kernel.") 343 | 344 | if sampler_type == "uncertainty": 345 | if not kernel_type == "RBF" or bq_model == "vbq": 346 | raise AssertionError("Uncertainty sampling requires RBF kernel.") 347 | 348 | if not type(self.prior) == torch.distributions.multivariate_normal.MultivariateNormal: 349 | if not bq_model == "vbq" and sampler_type == "prior": 350 | raise AssertionError("Non-Gaussian prior requires prior sampling with VBQ modelling.") 351 | 352 | if not kernel_type == "RBF": 353 | if not bq_model == "vbq" and sampler_type == "prior": 354 | raise AssertionError("Non-RBF kernel requires prior sampling with VBQ modelling.") 355 | 356 | 357 | -------------------------------------------------------------------------------- /BASQ/_acquisition_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._gaussian_calc import GaussianCalc 3 | from torch.distributions.multivariate_normal import MultivariateNormal 4 | 5 | 6 | class SquareRootAcquisitionFunction(GaussianCalc): 7 | def __init__(self, prior, model, device, n_gaussians=100, threshold=1e-5): 8 | """ 9 | Inherited the functions from GaussianCalc. 
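        Approximates the square-root (WSABI) acquisition function A(x) as a sparse
        mixture of Gaussians, so that A(x) can be evaluated (joint_pdf) and sampled
        from (sampling) in closed form for uncertainty sampling.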
10 | 11 | Args: 12 | - prior: torch.distributions, prior distribution 13 | - device: torch.device, cpu or cuda 14 | """ 15 | super().__init__(prior, device) 16 | self.n_gaussians = n_gaussians # number of Gaussians for uncertainty sampling 17 | self.threshold = threshold # threshold to cut off the small weights 18 | self.update(model) 19 | 20 | def update(self, model): 21 | """ 22 | Args: 23 | - model: gpytorch.models, function of GP model, typically model = self.gp.model in _basq.py 24 | """ 25 | self.parameters_extraction(model) 26 | self.wA, self.wAA, self.mu_AA, self.sigma_AA = self.sparseGMM() 27 | self.d_AA = len(self.mu_AA) 28 | self.w_mean, self.mu_mean, self.sig_mean = self.sparseGMM_mean() 29 | self.d_mean = len(self.mu_mean) 30 | 31 | def sparseGMM(self): 32 | """ 33 | See details on factorisation trick and sparse GMM sampler in Sapplementary. 34 | https://arxiv.org/abs/2206.04734 35 | 36 | Returns: 37 | - w1: torch.tensor, the weight of prior distribution 38 | - w2: torch.tensor, the weights of other normal distributions 39 | - mu2: torch.tensor, the mean vectors of other normal distributions 40 | - sigma2: torch.tensor, the covariance matrix of other normal distributions 41 | """ 42 | i, j = torch.where(self.woodbury_inv < 0) 43 | _w1_ = self.outputscale 44 | _w2_ = torch.abs((self.v**2) * self.woodbury_inv[i, j]) 45 | _Z = _w1_ + torch.sum(_w2_) 46 | _w1, _w2 = _w1_ / _Z, _w2_ / _Z 47 | 48 | Winv = self.W.inverse() 49 | Sinv = self.prior.covariance_matrix.inverse() 50 | sigma2 = (2 * Winv + Sinv).inverse() 51 | 52 | _idx = _w2.argsort(descending=True)[:self.n_gaussians] 53 | idx = _idx[_w2[_idx] > self.threshold] 54 | Xi = self.Xobs[i[idx]] 55 | Xj = self.Xobs[j[idx]] 56 | 57 | w2 = _w2[idx] 58 | mu2 = (sigma2 @ Winv @ (Xi + Xj).T).T + sigma2 @ Sinv @ self.prior.loc 59 | 60 | zA = _w1 + torch.sum(w2) 61 | w1, w2 = _w1 / zA, w2 / zA 62 | return w1, w2, mu2, sigma2 63 | 64 | def joint_pdf(self, x): 65 | """ 66 | Args: 67 | - x: torch.tensor, inputs. torch.Size(n_data, n_dims) 68 | 69 | Returns: 70 | - first/first+second: torch.tensor, the values of probability density function of approximated A(x) 71 | """ 72 | d_x = len(x) 73 | 74 | # calculate the first term 75 | Npdfs_A = self.utils.safe_mvn_prob( 76 | self.prior.loc, 77 | self.prior.covariance_matrix, 78 | x, 79 | ) 80 | first = self.wA * Npdfs_A 81 | 82 | # calulate the second term 83 | if len(self.wAA) == 0: 84 | return first 85 | else: 86 | x_AA = (torch.tile(self.mu_AA, (d_x, 1, 1)) - x.unsqueeze(1)).reshape( 87 | self.d_AA * d_x, self.n_dims 88 | ) 89 | Npdfs_AA = self.utils.safe_mvn_prob( 90 | torch.zeros(self.n_dims).to(self.device), 91 | self.sigma_AA, 92 | x_AA, 93 | ).reshape(d_x, self.d_AA) 94 | 95 | f_AA = self.wAA.unsqueeze(0) * Npdfs_AA 96 | second = f_AA.sum(axis=1) 97 | return first + second 98 | 99 | def sampling(self, n): 100 | """ 101 | Args: 102 | - n: int, number of samples to be sampled. 
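             The number of samples drawn from each mixture component is proportional to
             its weight (wA for the prior term, wAA for the pairwise Gaussian terms).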
103 | 104 | Returns: 105 | - samplesA/samplesAA: torch.tensor, the samples from approximated A(x) 106 | """ 107 | cntA = (n * self.wA).type(torch.int) 108 | samplesA = self.prior.sample(torch.Size([cntA])).to(self.device) 109 | 110 | if len(self.wAA) == 0: 111 | return samplesA 112 | else: 113 | cntAA = (n * self.wAA).type(torch.int) 114 | samplesAA = torch.cat([ 115 | MultivariateNormal( 116 | self.mu_AA[i], 117 | self.sigma_AA, 118 | ).sample(torch.Size([cnt])).to(self.device) 119 | for i, cnt in enumerate(cntAA) 120 | ]) 121 | return torch.cat([samplesA, samplesAA]) 122 | 123 | def sparseGMM_mean(self): 124 | """ 125 | Returns: 126 | - weights: torch.tensor, the weight of approximated GP mean functions 127 | - mu_mean: torch.tensor, the mean vectors of approximated GP mean functions 128 | - sig_prime: torch.tensor, the covariance matrix of approximated GP mean functions 129 | """ 130 | Winv = self.W.inverse() 131 | Sinv = self.prior.covariance_matrix.inverse() 132 | sig_prime = (Winv + Sinv).inverse() 133 | mu_prime = (sig_prime @ ( 134 | (Winv @ self.Xobs.T).T + Sinv @ self.prior.loc 135 | ).T).T 136 | npdfs = MultivariateNormal( 137 | self.prior.loc, 138 | self.W + self.prior.covariance_matrix, 139 | ).log_prob(self.Xobs).exp() 140 | omega_prime = self.woodbury_vector * npdfs 141 | 142 | weights = omega_prime / omega_prime.sum() 143 | W_prime = weights * MultivariateNormal( 144 | self.prior.loc, 145 | sig_prime, 146 | ).log_prob(mu_prime).exp() 147 | 148 | W_pos = W_prime[W_prime > 0].sum() 149 | W_neg = W_prime[W_prime < 0].sum().abs() 150 | N_pos = int(W_pos / (W_pos + W_neg) * self.n_gaussians) 151 | N_neg = self.n_gaussians - N_pos 152 | idx_pos = W_prime[W_prime > 0].argsort(descending=True)[:N_pos] 153 | idx_neg = W_prime[W_prime < 0].argsort()[:N_neg] 154 | weights_pos = weights[W_prime > 0][idx_pos] 155 | weights_neg = weights[W_prime < 0][idx_neg].abs() 156 | weights = torch.cat([weights_pos, weights_neg]) 157 | mu_pos = mu_prime[W_prime > 0][idx_pos] 158 | mu_neg = mu_prime[W_prime < 0][idx_neg] 159 | mu_mean = torch.cat([mu_pos, mu_neg]) 160 | 161 | idx_weights = weights > (self.threshold * weights.sum()) 162 | weights = weights[idx_weights] 163 | mu_mean = mu_mean[idx_weights] 164 | weights = weights / weights.sum() 165 | return weights, mu_mean, sig_prime 166 | 167 | def joint_pdf_mean(self, x): 168 | """ 169 | Args: 170 | - x: torch.tensor, inputs. torch.Size(n_data, n_dims) 171 | 172 | Returns: 173 | - first/first+second: torch.tensor, the values of probability density function of approximated GP mean functions 174 | """ 175 | d_x = len(x) 176 | 177 | x_AA = (torch.tile(self.mu_mean, (d_x, 1, 1)) - x.unsqueeze(1)).reshape( 178 | self.d_mean * d_x, self.n_dims 179 | ) 180 | Npdfs_AA = self.utils.safe_mvn_prob( 181 | torch.zeros(self.n_dims).to(self.device), 182 | self.sig_mean, 183 | x_AA, 184 | ).reshape(d_x, self.d_mean) 185 | 186 | f_AA = self.w_mean.unsqueeze(0) * Npdfs_AA 187 | pdf = f_AA.sum(axis=1) 188 | return pdf 189 | 190 | def sampling_mean(self, n): 191 | """ 192 | Args: 193 | - n: int, number of samples to be sampled. 
194 | 195 | Returns: 196 | - samples: torch.tensor, the samples from approximated GP mean functions 197 | """ 198 | cnts = (n * self.w_mean).type(torch.int) 199 | samples = torch.cat([ 200 | MultivariateNormal( 201 | self.mu_mean[i], 202 | self.sig_mean, 203 | ).sample(torch.Size([cnt])).to(self.device) 204 | for i, cnt in enumerate(cnts) 205 | ]) 206 | return samples 207 | 208 | 209 | class LogRootAcquisitionFunction(GaussianCalc): 210 | def __init__(self, prior, model, device, n_gaussians=100, threshold=1e-5): 211 | """ 212 | Inherited the functions from GaussianCalc. 213 | 214 | Args: 215 | - prior: torch.distributions, prior distribution 216 | - device: torch.device, cpu or cuda 217 | """ 218 | super().__init__(prior, device) 219 | self.n_gaussians = n_gaussians # number of Gaussians for uncertainty sampling 220 | self.threshold = threshold # threshold to cut off the small weights 221 | self.update(model) 222 | 223 | def update(self, model): 224 | """ 225 | Args: 226 | - model: gpytorch.models, function of GP model, typically model = self.gp.model in _basq.py 227 | """ 228 | self.parameters_extraction(model) 229 | self.w_A, self.w_B, self.w_C, self.mu_A, self.mu_B, self.mu_C, self.cov_A, self.cov_B, self.cov_C = self.sparseGMM_mean() 230 | self.d_A, self.d_B, self.d_C = len(self.mu_A), len(self.mu_B), len(self.mu_C) 231 | self.w_cov, self.mu_cov, self.sig_cov = self.sparseGMM() 232 | self.d_cov = len(self.mu_cov) 233 | 234 | def sparseGMM_mean(self): 235 | """ 236 | Returns: 237 | - w_A, w_B, w_C: torch.tensor, the weights of first, second, and third terms of Gaussians 238 | - mu_A, mu_B, mu_C: torch.tensor, the mean vectors of first, second, and third terms of Gaussians 239 | - cov_A, cov_B, cov_C: torch.tensor, the covariance matrix of first, second, and third terms of Gaussians 240 | """ 241 | X_ij_minus = (self.Xobs.unsqueeze(1) - self.Xobs.unsqueeze(0)).reshape(self.n_data**2, self.n_dims) 242 | Npdfs = self.utils.safe_mvn_prob( 243 | torch.zeros(self.n_dims).to(self.device), 244 | 2 * self.W, 245 | X_ij_minus, 246 | ).reshape(self.n_data, self.n_data) 247 | 248 | w = self.v * self.woodbury_vector 249 | w_prime = 0.5 * (self.v**2) * self.woodbury_inv * Npdfs 250 | weights = torch.cat([w[w > 0], w_prime[w_prime < 0].abs(), torch.tensor(1).unsqueeze(0).to(self.device)]) 251 | 252 | idx = weights.argsort(descending=True)[:self.n_gaussians] 253 | S = weights[idx].sum() 254 | idx_all = (weights[idx] / S > self.threshold) 255 | S = weights[idx[idx_all]].sum() 256 | 257 | mu_pi = self.prior.loc 258 | cov_pi = self.prior.covariance_matrix 259 | X_ij_plus = (self.Xobs.unsqueeze(1) + self.Xobs.unsqueeze(0)) / 2 260 | thresh1 = w[w > 0].size(0) 261 | thresh2 = thresh1 + w_prime[w_prime < 0].size(0) 262 | 263 | idx_update = idx[idx_all] 264 | condition1 = idx_update < thresh1 265 | condition2 = (idx_update > thresh1) * (idx_update < thresh2) 266 | condition3 = idx_update == thresh2 267 | if condition1.any(): 268 | _idx = idx_update[condition1] 269 | w_A = weights[_idx] / S 270 | mu_A = self.Xobs[w > 0][_idx] 271 | cov_A = self.W 272 | else: 273 | w_A = torch.tensor([]) 274 | mu_A = torch.tensor([]) 275 | cov_A = torch.tensor([]) 276 | 277 | if condition2.any(): 278 | _idx = idx_update[condition2] - thresh1 279 | w_B = weights[idx_update[condition2]] / S 280 | mu_B = X_ij_plus[w_prime < 0][_idx] 281 | cov_B = self.W / 2 282 | else: 283 | w_B = torch.tensor([]) 284 | mu_B = torch.tensor([]) 285 | cov_B = torch.tensor([]) 286 | 287 | if condition3.any(): 288 | _idx = idx_update[condition3] 289 | 
w_C = weights[_idx] / S 290 | mu_C = mu_pi 291 | cov_C = cov_pi 292 | else: 293 | w_C = torch.tensor([]) 294 | mu_C = torch.tensor([]) 295 | cov_C = torch.tensor([]) 296 | return w_A, w_B, w_C, mu_A, mu_B, mu_C, cov_A, cov_B, cov_C 297 | 298 | def joint_pdf_mean(self, x): 299 | """ 300 | Args: 301 | - x: torch.tensor, inputs. torch.Size(n_data, n_dims) 302 | 303 | Returns: 304 | - pdf_sum: torch.tensor, the values of probability density function of approximated log (mu_g(x) pi(x)) 305 | """ 306 | d_x = len(x) 307 | pdf_sum = 0 308 | 309 | if not self.d_A == 0: 310 | x_A = (torch.tile(self.mu_A, (d_x, 1, 1)) - x.unsqueeze(1)).reshape( 311 | self.d_A * d_x, self.n_dims 312 | ) 313 | Npdfs_A = self.utils.safe_mvn_prob( 314 | torch.zeros(self.n_dims).to(self.device), 315 | self.cov_A, 316 | x_A, 317 | ).reshape(d_x, self.d_A) 318 | 319 | f_A = self.w_A.unsqueeze(0) * Npdfs_A 320 | pdf_sum += f_A.sum(axis=1) 321 | 322 | if not self.d_B == 0: 323 | x_B = (torch.tile(self.mu_B, (d_x, 1, 1)) - x.unsqueeze(1)).reshape( 324 | self.d_B * d_x, self.n_dims 325 | ) 326 | Npdfs_B = self.utils.safe_mvn_prob( 327 | torch.zeros(self.n_dims).to(self.device), 328 | self.cov_B, 329 | x_B, 330 | ).reshape(d_x, self.d_B) 331 | 332 | f_B = self.w_B.unsqueeze(0) * Npdfs_B 333 | pdf_sum += f_B.sum(axis=1) 334 | 335 | if not self.d_C == 0: 336 | Npdfs_C = self.utils.safe_mvn_prob( 337 | self.prior.loc, 338 | self.prior.covariance_matrix, 339 | x, 340 | ) 341 | f_C = self.w_C * Npdfs_C 342 | pdf_sum += f_C 343 | 344 | return pdf_sum 345 | 346 | def sampling_mean(self, n): 347 | """ 348 | Args: 349 | - n: int, number of samples to be sampled. 350 | 351 | Returns: 352 | - samples: torch.tensor, the samples from approximated GP mean functions 353 | """ 354 | samples = torch.tensor([]).to(self.device) 355 | 356 | if not self.d_A == 0: 357 | cnts = (n * self.w_A).type(torch.int) 358 | samples_A = torch.cat([ 359 | MultivariateNormal( 360 | self.mu_A[i], 361 | self.cov_A, 362 | ).sample(torch.Size([cnt])).to(self.device) 363 | for i, cnt in enumerate(cnts) 364 | ]) 365 | samples = torch.cat([samples, samples_A]) 366 | 367 | if not self.d_B == 0: 368 | cnts = (n * self.w_B).type(torch.int) 369 | samples_B = torch.cat([ 370 | MultivariateNormal( 371 | self.mu_B[i], 372 | self.cov_B, 373 | ).sample(torch.Size([cnt])).to(self.device) 374 | for i, cnt in enumerate(cnts) 375 | ]) 376 | samples = torch.cat([samples, samples_B]) 377 | 378 | if not self.d_C == 0: 379 | cnt = (n * self.w_C).type(torch.int) 380 | samples_C = MultivariateNormal( 381 | self.mu_C, 382 | self.cov_C, 383 | ).sample(torch.Size([cnt])).to(self.device) 384 | samples = torch.cat([samples, samples_C]) 385 | 386 | return samples 387 | 388 | def sparseGMM(self): 389 | """ 390 | Returns: 391 | - w_A: torch.tensor, the weights of Gaussians 392 | - mu_A: torch.tensor, the mean vector of Gaussians 393 | - cov_A: torch.tensor, the covariance matrix of Gaussians 394 | """ 395 | X_ij_plus = (self.Xobs.unsqueeze(1) + self.Xobs.unsqueeze(0)).reshape(self.n_data**2, self.n_dims) 396 | self.X_ij_plus_half = X_ij_plus / 2 397 | X_ij_minus = (self.Xobs.unsqueeze(1) - self.Xobs.unsqueeze(0)).reshape(self.n_data**2, self.n_dims) 398 | mu_pi = self.prior.loc 399 | cov_pi = self.prior.covariance_matrix 400 | sig_prime = (cov_pi.inverse() + 2 * self.W.inverse()).inverse() 401 | mu_prime = (sig_prime @ (cov_pi @ mu_pi + (self.W.inverse() @ X_ij_plus.T).T).T).T 402 | 403 | Npdfs_Xij_2W = self.utils.safe_mvn_prob( 404 | torch.zeros(self.n_dims).to(self.device), 405 | 2 * self.W, 406 
| X_ij_minus, 407 | ).reshape(self.n_data, self.n_data) 408 | 409 | Npdfs_Xij_W_prior = self.utils.safe_mvn_prob( 410 | mu_pi, 411 | self.W / 2 + cov_pi, 412 | self.X_ij_plus_half, 413 | ).reshape(self.n_data, self.n_data) 414 | 415 | mij = (self.v**2) * self.woodbury_vector.unsqueeze(1) @ self.woodbury_vector.unsqueeze(0) * Npdfs_Xij_2W * Npdfs_Xij_W_prior 416 | Wij = (self.v**2) * self.woodbury_inv * Npdfs_Xij_2W 417 | mij = mij.reshape(mij.numel()) 418 | 419 | n_eff = int(1.1 * self.n_gaussians) 420 | idx = mij.argsort(descending=True)[:n_eff] 421 | 422 | self.Wij_flat = Wij.reshape(Wij.numel()) 423 | Mij = torch.squeeze(mij[idx] * (self.calc_Taylor(mu_prime[idx]) * torch.exp(self.lengthscale) - 1)) 424 | indice = idx[Mij.argsort(descending=True)][:self.n_gaussians] 425 | Mij = Mij[Mij.argsort(descending=True)][:self.n_gaussians] 426 | 427 | S_norm = Mij[Mij > 0].sum() 428 | if S_norm == 0: 429 | S_norm = 1 430 | w_A = Mij / S_norm 431 | 432 | idx_thresh = w_A > self.threshold 433 | if idx_thresh.sum() == 0: 434 | idx_thresh = torch.tensor([0]) 435 | w_A = torch.tensor([1]) 436 | else: 437 | w_A = w_A[idx_thresh] 438 | w_A /= w_A.sum() 439 | 440 | mu_A = mu_prime[indice][idx_thresh] 441 | cov_A = sig_prime 442 | 443 | return w_A, mu_A, cov_A 444 | 445 | def joint_pdf(self, x): 446 | """ 447 | Args: 448 | - x: torch.tensor, inputs. torch.Size(n_data, n_dims) 449 | 450 | Returns: 451 | - first/first+second: torch.tensor, the values of probability density function of approximated A(x) 452 | """ 453 | d_x = len(x) 454 | 455 | x_cov = (torch.tile(self.mu_cov, (d_x, 1, 1)) - x.unsqueeze(1)).reshape( 456 | self.d_cov * d_x, self.n_dims 457 | ) 458 | Npdfs_cov = self.utils.safe_mvn_prob( 459 | torch.zeros(self.n_dims).to(self.device), 460 | self.sig_cov, 461 | x_cov, 462 | ).reshape(d_x, self.d_cov) 463 | 464 | f_cov = self.w_cov.unsqueeze(0) * Npdfs_cov 465 | pdf_sum = f_cov.sum(axis=1) 466 | return pdf_sum 467 | 468 | def sampling(self, n): 469 | """ 470 | Args: 471 | - n: int, number of samples to be sampled. 472 | 473 | Returns: 474 | - samples: torch.tensor, the samples from approximated GP mean functions 475 | """ 476 | samples = torch.tensor([]) 477 | 478 | cnts = (n * self.w_cov).type(torch.int) 479 | samples = torch.cat([ 480 | MultivariateNormal( 481 | self.mu_cov[i], 482 | self.sig_cov, 483 | ).sample(torch.Size([cnt])).to(self.device) 484 | for i, cnt in enumerate(cnts) 485 | ]) 486 | return samples 487 | -------------------------------------------------------------------------------- /BASQ/_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._gp import predict 3 | from ._utils import Utils 4 | from ._acquisition_function import SquareRootAcquisitionFunction, LogRootAcquisitionFunction 5 | 6 | 7 | class PriorSampler: 8 | def __init__(self, prior, n_rec, nys_ratio, device): 9 | """ 10 | Args: 11 | - prior: torch.distributions, prior distribution 12 | - n_rec: int, number of subsamples for empirical measure of kernel recomnbination 13 | - nys_ratio: float, subsubsampling ratio for Nystrom. 
14 | - device: torch.device, cpu or cuda 15 | """ 16 | self.prior = prior 17 | self.n_rec = n_rec 18 | self.nys_ratio = nys_ratio 19 | self.device = device 20 | 21 | def __call__(self, n_rec): 22 | """ 23 | Args: 24 | - n_rec: int, number of subsamples for empirical measure of kernel recomnbination 25 | 26 | Returns: 27 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 28 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 29 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 30 | """ 31 | pts_rec = self.prior.sample(sample_shape=torch.Size([n_rec])).to(self.device) 32 | pts_nys = pts_rec[:int(self.n_rec * self.nys_ratio)] 33 | w = torch.ones(n_rec) / n_rec 34 | return pts_nys, pts_rec, w.to(self.device) 35 | 36 | 37 | class UncertaintySampler(SquareRootAcquisitionFunction): 38 | def __init__( 39 | self, 40 | prior, 41 | model, 42 | n_rec, 43 | nys_ratio, 44 | device, 45 | sampling_method="approx", 46 | ratio=0.5, 47 | ratio_super=100, 48 | n_gaussians=100, 49 | threshold=1e-5, 50 | ): 51 | """ 52 | Args: 53 | - prior: torch.distributions, prior distribution. 54 | - model: gpytorch.models, function of GP model 55 | - n_rec: int, subsampling size for kernel recombination 56 | - nys_ratio: float, subsubsampling ratio for Nystrom. Number of Nystrom samples is nys_ratio * n_rec 57 | - device: torch.device, cpu or cuda 58 | - sampling_method; string, ["exact", "approx"] 59 | - ratio: float, mixing ratio of prior and uncertainty sampling, 0 < r < 1 60 | - n_gaussians: int, number of Gaussians approximating the GP-modelled acquisition function 61 | - threshold: float, threshold to cut off the insignificant Gaussians 62 | """ 63 | super().__init__(prior, model, device, n_gaussians=n_gaussians, threshold=threshold) 64 | self.model = model 65 | self.ratio = ratio 66 | self.nys_ratio = nys_ratio 67 | self.ratio_super = ratio_super 68 | self.device = device 69 | self.sampling_method = sampling_method 70 | self.utils = Utils(device) 71 | 72 | def pdf(self, X): 73 | """ 74 | Args: 75 | - X: torch.tensor, inputs 76 | 77 | Returns: 78 | - pdf: the value at given X of probability density function of approximated sparse Gaussian Mixture Model (GMM) 79 | """ 80 | if self.ratio == 0: 81 | return self.prior.log_prob(X).exp() 82 | elif self.ratio == 1: 83 | return self.joint_pdf(X) 84 | else: 85 | g_pdf = self.joint_pdf(X) 86 | f_pdf = self.prior.log_prob(X).exp() 87 | return ((1 - self.ratio) * f_pdf + self.ratio * g_pdf) / f_pdf 88 | 89 | def SIR(self, X, weights, n_return): 90 | """ 91 | Sequentail Importance Resample (SIR). 92 | Resample from the weighted samples via importance sampling. 93 | 94 | Args: 95 | - X: torch.tensor, inputs 96 | - weights: torch.tensor, weights for importance sampling. This is not necessarily required to be normalised. 97 | - n_return: torch.tensor, inputs, number of samples to be returned. 98 | 99 | Returns: 100 | - samples: resampled samples. 101 | """ 102 | draw = torch.multinomial(weights, n_return) 103 | return X[draw] 104 | 105 | def approx(self, n): 106 | """ 107 | Proposal distribution g(x) = (1-r) f(x) + r A(x), 108 | f(x) = π(x), 109 | where r is self.ratio, m(x) and C(x) are the mean and varinace of square-root kernel. 110 | weights w_IS = f(x) / g(x) 111 | samples for Nystrom should be sampled from f(x), thus we adopt SIR. 112 | pts_nys <- SIR(pts_rec, weights) is equivalent to be sampled from f(x). 
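# --- Editor's illustration (not part of the module): a minimal, self-contained
# sketch of the sampling-importance-resampling (SIR) step described above.
# A toy 1-D Gaussian stands in for the prior pi(x) and an over-dispersed
# Gaussian for the proposal g(x); every name here is a placeholder, not BASQ API.
import torch
from torch.distributions import Normal

prior = Normal(0.0, 1.0)                 # f(x) = pi(x)
proposal = Normal(0.0, 2.0)              # g(x), deliberately over-dispersed

n_rec, n_nys = 1000, 100
pts_rec = proposal.sample(torch.Size([n_rec]))                          # x ~ g(x)
w = torch.exp(prior.log_prob(pts_rec) - proposal.log_prob(pts_rec))     # w = f/g
weights = w / w.sum()

draw = torch.multinomial(weights, n_nys)   # resample according to the weights
pts_nys = pts_rec[draw]                    # approximately distributed as f(x)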
113 | 114 | Args: 115 | - n: int, number of samples to be returned 116 | 117 | Returns: 118 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 119 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 120 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 121 | """ 122 | if self.ratio == 0: 123 | pts_rec = self.prior.sample(torch.Size([n])) 124 | elif self.ratio == 1: 125 | pts_rec = self.sampling(n) 126 | else: 127 | pts_rec = torch.cat([ 128 | self.sampling(int(self.ratio * n)), 129 | self.prior.sample(torch.Size([int((1 - self.ratio) * n)])), 130 | ]) 131 | 132 | mean, var = predict(pts_rec, self.model) 133 | w = torch.exp(torch.log(torch.abs(mean)) + self.prior.log_prob(pts_rec) - torch.nan_to_num(self.pdf(pts_rec))) 134 | w = torch.nan_to_num(w) 135 | if torch.sum(w) == 0: 136 | weights = torch.ones(len(w)) / len(w) 137 | else: 138 | weights = w / torch.sum(w) 139 | 140 | n_nys = int(n * self.nys_ratio) 141 | pts_nys = self.SIR(pts_rec, weights, n_nys) 142 | return pts_nys, pts_rec, weights 143 | 144 | def SIR_from_mean(self, n_super, n): 145 | """ 146 | Proposal distribution g(x) = B(x), 147 | f(x) = |m(x)| π(x), 148 | weights w_IS = f(x) / g(x) 149 | 150 | Args: 151 | - n_super: int, number of supersamples for SIR 152 | - n: int, number of samples to be returned 153 | 154 | Returns: 155 | - samples: resampled samples. 156 | """ 157 | X_pi = self.sampling_mean(n_super) 158 | mean, _ = predict(X_pi, self.model) 159 | mean_log = mean.abs().log() 160 | prior_log = self.utils.safe_mvn_prob(self.prior.loc, self.prior.covariance_matrix, X_pi).log() 161 | sampler_log = torch.nan_to_num(self.joint_pdf_mean(X_pi)) 162 | 163 | w_mpi_B = torch.exp( 164 | mean_log + prior_log - sampler_log 165 | ) 166 | X_f = self.SIR(X_pi, w_mpi_B, n) 167 | return X_f 168 | 169 | def SIR_from_AF(self, n_super, n): 170 | """ 171 | Proposal distribution g(x) = A(x), 172 | f(x) = C(x)π(x), 173 | weights w_IS = f(x) / g(x) 174 | 175 | Args: 176 | - n_super: int, number of supersamples for SIR 177 | - n: int, number of samples to be returned 178 | 179 | Returns: 180 | - samples: resampled samples. 
181 | """ 182 | X_A = self.sampling(n_super) 183 | _, var_A = predict(X_A, self.model) 184 | var_log = var_A.log() 185 | prior_log = self.utils.safe_mvn_prob(self.prior.loc, self.prior.covariance_matrix, X_A).log() 186 | sampler_log = torch.nan_to_num(self.joint_pdf(X_A)).log() 187 | 188 | w_C_A = torch.exp( 189 | var_log + prior_log - sampler_log 190 | ) 191 | X_rec = self.SIR(X_A, w_C_A, n) 192 | return X_rec 193 | 194 | def calc_weights(self, pts_rec): 195 | """ 196 | weights w_IS = f(x) / g(x) 197 | g(x) = (1-r) f(x) + r var(x)π(x), 198 | f(x) = |m(x)| π(x) 199 | warped GP(m(x), var(x,x)) 200 | 201 | Args: 202 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 203 | 204 | Returns: 205 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 206 | """ 207 | mean_rec, var_rec = predict(pts_rec, self.model) 208 | f_rec = torch.exp(torch.abs(mean_rec).log() + self.prior.log_prob(pts_rec)) 209 | if self.ratio < 1: 210 | g_rec = torch.exp( 211 | torch.log(self.ratio * var_rec + (1 - self.ratio) * torch.abs(mean_rec)) 212 | + self.prior.log_prob(pts_rec) 213 | ) 214 | else: 215 | g_rec = torch.exp( 216 | torch.tensor(self.ratio).log() + var_rec.log() + self.prior.log_prob(pts_rec) 217 | ) 218 | w_IC = f_rec / g_rec 219 | w_IC = w_IC / w_IC.sum() 220 | return w_IC 221 | 222 | def exact(self, n): 223 | """ 224 | warped GP(m(x), var(x,x)) 225 | Proposal distribution g(x) = (1-r) f(x) + r C(x)π(x), 226 | f(x) = |m(x)| π(x), 227 | where r is self.ratio, m(x) and C(x) are the mean and varinace of square-root kernel. 228 | weights w_IS = f(x) / g(x) 229 | samples for Nystrom should be sampled from f(x), thus we adopt SIR. 230 | pts_nys <- SIR(pts_rec, weights) is equivalent to be sampled from f(x). 231 | If r = 0, this simply samples from f(x) = |m(x)| π(x). 232 | If r = 1, this becomes pure uncertainty sampling. 
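# --- Editor's illustration (not part of the module): the importance weight
# defined above is w_IS = f(x)/g(x) with f(x) = |m(x)| pi(x) and
# g(x) = (1-r)|m(x)| pi(x) + r C(x) pi(x), so the prior density cancels in the
# ratio and only the GP mean and variance at each point are needed. The numbers
# below are made up purely to show the arithmetic.
import torch

r = 0.5                                      # mixing ratio (self.ratio)
mean = torch.tensor([0.8, -0.1, 0.3])        # GP mean m(x) at three points
var = torch.tensor([0.2, 0.9, 0.4])          # GP variance C(x) at the same points

w_IS = mean.abs() / ((1 - r) * mean.abs() + r * var)   # pi(x) cancels
w_IS = w_IS / w_IS.sum()                               # normalised, as in calc_weights
# points whose |m(x)| is large relative to their variance receive the larger weights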
233 | 234 | Args: 235 | - n: int, number of samples to be returned 236 | 237 | Returns: 238 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 239 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 240 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 241 | """ 242 | n_nys = int(n * self.nys_ratio) 243 | 244 | if self.ratio == 0: 245 | n_super = int(self.ratio_super * n) 246 | pts_rec = self.SIR_from_mean(n_super, n) 247 | pts_nys = pts_rec[:n_nys] 248 | w_IC = torch.ones(n).to(self.device) / n 249 | return pts_nys, pts_rec, w_IC 250 | elif self.ratio == 1: 251 | n_super = int(self.ratio_super * n) 252 | pts_rec = self.SIR_from_AF(n_super, n) 253 | w_IC = self.calc_weights(pts_rec) 254 | pts_nys = self.SIR_from_mean(n, n_nys) 255 | return pts_nys, pts_rec, w_IC 256 | else: 257 | n_super = int(self.ratio_super * (1 - self.ratio) * n) 258 | n_pi = int((1 - self.ratio) * n) 259 | X_f = self.SIR_from_mean(n_super, n_pi) 260 | pts_nys = X_f[:n_nys] 261 | 262 | n_super = int(self.ratio_super * self.ratio * n) 263 | n_rec = int(self.ratio * n) 264 | X_rec = self.SIR_from_AF(n_super, n_rec) 265 | pts_rec = torch.cat([X_rec, X_f]) 266 | w_IC = self.calc_weights(pts_rec) 267 | return pts_nys, pts_rec, w_IC 268 | 269 | def __call__(self, n): 270 | """ 271 | Args: 272 | - n: int, number of samples to be returned 273 | 274 | Returns: 275 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 276 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 277 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 278 | """ 279 | if self.sampling_method == "approx": 280 | return self.approx(n) 281 | elif self.sampling_method == "exact": 282 | return self.exact(n) 283 | else: 284 | raise Exception("The given sampling method is undefined.") 285 | 286 | 287 | class LogUncertaintySampler(LogRootAcquisitionFunction): 288 | def __init__( 289 | self, 290 | prior, 291 | gp, 292 | n_rec, 293 | nys_ratio, 294 | device, 295 | sampling_method="approx", 296 | ratio=0.5, 297 | ratio_super=100, 298 | n_gaussians=100, 299 | threshold=1e-5, 300 | ): 301 | """ 302 | Args: 303 | - prior: torch.distributions, prior distribution. 304 | - model: gpytorch.models, function of GP model 305 | - n_rec: int, subsampling size for kernel recombination 306 | - nys_ratio: float, subsubsampling ratio for Nystrom. Number of Nystrom samples is nys_ratio * n_rec 307 | - device: torch.device, cpu or cuda 308 | - sampling_method; string, ["exact", "approx"] 309 | - ratio: float, mixing ratio of prior and uncertainty sampling, 0 < r < 1 310 | - n_gaussians: int, number of Gaussians approximating the GP-modelled acquisition function 311 | - threshold: float, threshold to cut off the insignificant Gaussians 312 | """ 313 | super().__init__(prior, gp.model, device, n_gaussians=n_gaussians, threshold=threshold) 314 | self.gp = gp 315 | self.ratio = ratio 316 | self.nys_ratio = nys_ratio 317 | self.ratio_super = ratio_super 318 | self.device = device 319 | self.utils = Utils(device) 320 | 321 | def SIR(self, X, weights, n_return): 322 | """ 323 | Sequentail Importance Resample (SIR). 324 | Resample from the weighted samples via importance sampling. 325 | 326 | Args: 327 | - X: torch.tensor, inputs 328 | - weights: torch.tensor, weights for importance sampling. This is not necessarily required to be normalised. 
329 | - n_return: torch.tensor, inputs, number of samples to be returned. 330 | 331 | Returns: 332 | - samples: resampled samples. 333 | """ 334 | draw = torch.multinomial(weights, n_return) 335 | return X[draw] 336 | 337 | def SIR_from_mean(self, n_super, n): 338 | """ 339 | Proposal distribution g(x) = B(x), 340 | f(x) = |mu_g(x)| π(x), 341 | weights w_IS = f(x) / g(x) 342 | 343 | Args: 344 | - n_super: int, number of supersamples for SIR 345 | - n: int, number of samples to be returned 346 | 347 | Returns: 348 | - samples: resampled samples. 349 | """ 350 | X_pi = self.sampling_mean(n_super) 351 | mu_g = self.gp.gspace_mean_predict(X_pi) 352 | mu_log = mu_g.abs().log() 353 | prior_log = self.utils.safe_mvn_prob(self.prior.loc, self.prior.covariance_matrix, X_pi).log() 354 | sampler_log = torch.nan_to_num(self.joint_pdf_mean(X_pi)).log() 355 | 356 | w_mpi_B = torch.exp( 357 | mu_log + prior_log - sampler_log 358 | ) 359 | X_f = self.SIR(X_pi, w_mpi_B, n) 360 | return X_f 361 | 362 | def SIR_from_AF(self, n_super, n): 363 | """ 364 | Proposal distribution g(x) = A(x), 365 | f(x) = var_g(x)π(x), 366 | weights w_IS = f(x) / g(x) 367 | 368 | Args: 369 | - n_super: int, number of supersamples for SIR 370 | - n: int, number of samples to be returned 371 | 372 | Returns: 373 | - samples: resampled samples. 374 | """ 375 | X_A = self.sampling(n_super) 376 | _, var_A = self.gp.gspace_predict(X_A) 377 | 378 | var_log = var_A.log() 379 | prior_log = self.utils.safe_mvn_prob(self.prior.loc, self.prior.covariance_matrix, X_A).log() 380 | sampler_log = torch.nan_to_num(self.joint_pdf(X_A)).log() 381 | 382 | w_C_A = torch.exp( 383 | var_log + prior_log - sampler_log 384 | ) 385 | 386 | X_rec = self.SIR(X_A, w_C_A, n) 387 | return X_rec 388 | 389 | def calc_weights(self, pts_rec): 390 | """ 391 | gspace GP(mu_g(x), var_g(x,x)) 392 | weights w_IS = f(x) / g(x) 393 | g(x) = (1-r) f(x) + r var_g(x)π(x), 394 | f(x) = |mu_g(x)| π(x) 395 | 396 | Args: 397 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 398 | 399 | Returns: 400 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 401 | """ 402 | mu_g_rec, var_g_rec = self.gp.gspace_predict(pts_rec) 403 | f_rec = torch.exp(torch.abs(mu_g_rec).log() + self.prior.log_prob(pts_rec)) 404 | if self.ratio < 1: 405 | g_rec = torch.exp( 406 | torch.log(self.ratio * var_g_rec + (1 - self.ratio) * torch.abs(mu_g_rec)) 407 | + self.prior.log_prob(pts_rec) 408 | ) 409 | else: 410 | g_rec = torch.exp( 411 | torch.tensor(self.ratio).log() + var_g_rec.log() + self.prior.log_prob(pts_rec) 412 | ) 413 | w_IC = f_rec / g_rec 414 | w_IC = w_IC / w_IC.sum() 415 | return w_IC 416 | 417 | def exact(self, n): 418 | """ 419 | Proposal distribution g(x) = (1-r) f(x) + r var_g(x)π(x), 420 | f(x) = |mu_g(x)| π(x), 421 | where r is self.ratio, m(x) and C(x) are the mean and varinace of square-root kernel. 422 | weights w_IS = f(x) / g(x) 423 | samples for Nystrom should be sampled from f(x), thus we adopt SIR. 424 | pts_nys <- SIR(pts_rec, weights) is equivalent to be sampled from f(x). 425 | If r = 0, this simply samples from f(x) = |m(x)| π(x). 426 | If r = 1, this becomes pure uncertainty sampling. 
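# --- Editor's illustration (not part of the module): the weights above are
# assembled in log space, torch.exp(log|mu_g| + log pi - log g), rather than by
# multiplying and dividing densities directly. One practical benefit is shown
# below with made-up magnitudes small enough to underflow double precision.
import torch

mu_abs = torch.tensor([1e-100], dtype=torch.float64)        # |mu_g(x)|
prior_pdf = torch.tensor([1e-250], dtype=torch.float64)     # pi(x)
proposal_pdf = torch.tensor([5e-250], dtype=torch.float64)  # g(x)

naive = mu_abs * prior_pdf / proposal_pdf    # the product underflows to 0 before dividing
stable = torch.exp(mu_abs.log() + prior_pdf.log() - proposal_pdf.log())
print(naive)    # tensor([0.], dtype=torch.float64)
print(stable)   # roughly 2e-101, the correct ratio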
427 | 428 | Args: 429 | - n: int, number of samples to be returned 430 | 431 | Returns: 432 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 433 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 434 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 435 | """ 436 | n_nys = int(n * self.nys_ratio) 437 | 438 | if self.ratio == 0: 439 | n_super = int(self.ratio_super * n) 440 | pts_rec = self.SIR_from_mean(n_super, n) 441 | pts_nys = pts_rec[:n_nys] 442 | w_IC = torch.ones(n).to(self.device) / n 443 | return pts_nys, pts_rec, w_IC 444 | elif self.ratio == 1: 445 | n_super = int(self.ratio_super * n) 446 | pts_rec = self.SIR_from_AF(n_super, n) 447 | w_IC = self.calc_weights(pts_rec) 448 | pts_nys = self.SIR_from_mean(n, n_nys) 449 | return pts_nys, pts_rec, w_IC 450 | else: 451 | n_super = int(self.ratio_super * (1 - self.ratio) * n) 452 | n_pi = int((1 - self.ratio) * n) 453 | X_f = self.SIR_from_mean(n_super, n_pi) 454 | pts_nys = X_f[:n_nys] 455 | 456 | n_super = int(self.ratio_super * self.ratio * n) 457 | n_rec = int(self.ratio * n) 458 | X_rec = self.SIR_from_AF(n_super, n_rec) 459 | pts_rec = torch.cat([X_rec, X_f]) 460 | w_IC = self.calc_weights(pts_rec) 461 | return pts_nys, pts_rec, w_IC 462 | 463 | def __call__(self, n): 464 | """ 465 | Args: 466 | - n: int, number of samples to be returned 467 | 468 | Returns: 469 | - pts_nys: torch.tensor, subsamples for low-rank approximation via Nyström method 470 | - pts_rec: torch.tensor, subsamples for empirical measure of kernel recomnbination 471 | - w_IS: torch.tensor, weights for importance sampling if pts_rec is not sampled from the prior 472 | """ 473 | return self.exact(n) 474 | -------------------------------------------------------------------------------- /BASQ/_lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from functools import reduce 5 | from copy import deepcopy 6 | from torch.optim import Optimizer 7 | 8 | 9 | def is_legal(v): 10 | """ 11 | Checks that tensor is not NaN or Inf. 12 | 13 | Args: 14 | v (tensor): tensor to be checked 15 | 16 | """ 17 | legal = not torch.isnan(v).any() and not torch.isinf(v) 18 | 19 | return legal 20 | 21 | 22 | def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False): 23 | """ 24 | Gives the minimizer and minimum of the interpolating polynomial over given points 25 | based on function and derivative information. Defaults to bisection if no critical 26 | points are valid. 27 | 28 | Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight 29 | modifications. 30 | 31 | Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere 32 | Last edited 12/6/18. 33 | 34 | Args: 35 | points (nparray): two-dimensional array with each point of form [x f g] 36 | x_min_bound (float): minimum value that brackets minimum (default: minimum of points) 37 | x_max_bound (float): maximum value that brackets minimum (default: maximum of points) 38 | plot (bool): plot interpolating polynomial 39 | 40 | Returns: 41 | x_sol (float): minimizer of interpolating polynomial 42 | F_min (float): minimum of interpolating polynomial 43 | 44 | Note: 45 | . 
Set f or g to np.nan if they are unknown 46 | 47 | """ 48 | no_points = points.shape[0] 49 | order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1 50 | 51 | x_min = np.min(points[:, 0]) 52 | x_max = np.max(points[:, 0]) 53 | 54 | # compute bounds of interpolation area 55 | if x_min_bound is None: 56 | x_min_bound = x_min 57 | if x_max_bound is None: 58 | x_max_bound = x_max 59 | 60 | # explicit formula for quadratic interpolation 61 | if no_points == 2 and order == 2 and plot is False: 62 | # Solution to quadratic interpolation is given by: 63 | # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2 64 | # x_min = x1 - g1/(2a) 65 | # if x1 = 0, then is given by: 66 | # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2)) 67 | 68 | if points[0, 0] == 0: 69 | x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])) 70 | else: 71 | a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2 72 | x_sol = points[0, 0] - points[0, 2] / (2 * a) 73 | 74 | x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound) 75 | 76 | # explicit formula for cubic interpolation 77 | elif no_points == 2 and order == 3 and plot is False: 78 | # Solution to cubic interpolation is given by: 79 | # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2)) 80 | # d2 = sqrt(d1^2 - g1*g2) 81 | # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)) 82 | d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0])) 83 | d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2]) 84 | if np.isreal(d2): 85 | x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2)) 86 | x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound) 87 | else: 88 | x_sol = (x_max_bound + x_min_bound) / 2 89 | 90 | # solve linear system 91 | else: 92 | # define linear constraints 93 | A = np.zeros((0, order + 1)) 94 | b = np.zeros((0, 1)) 95 | 96 | # add linear constraints on function values 97 | for i in range(no_points): 98 | if not np.isnan(points[i, 1]): 99 | constraint = np.zeros((1, order + 1)) 100 | for j in range(order, -1, -1): 101 | constraint[0, order - j] = points[i, 0] ** j 102 | A = np.append(A, constraint, 0) 103 | b = np.append(b, points[i, 1]) 104 | 105 | # add linear constraints on gradient values 106 | for i in range(no_points): 107 | if not np.isnan(points[i, 2]): 108 | constraint = np.zeros((1, order + 1)) 109 | for j in range(order): 110 | constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1) 111 | A = np.append(A, constraint, 0) 112 | b = np.append(b, points[i, 2]) 113 | 114 | # check if system is solvable 115 | if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]: 116 | x_sol = (x_min_bound + x_max_bound) / 2 117 | f_min = np.Inf 118 | else: 119 | # solve linear system for interpolating polynomial 120 | coeff = np.linalg.solve(A, b) 121 | 122 | # compute critical points 123 | dcoeff = np.zeros(order) 124 | for i in range(len(coeff) - 1): 125 | dcoeff[i] = coeff[i] * (order - i) 126 | 127 | crit_pts = np.array([x_min_bound, x_max_bound]) 128 | crit_pts = np.append(crit_pts, points[:, 0]) 129 | 130 | if not np.isinf(dcoeff).any(): 131 | roots = np.roots(dcoeff) 132 | crit_pts = np.append(crit_pts, roots) 133 | 134 | # test critical points 135 | f_min = np.Inf 136 | x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection 137 | for crit_pt in crit_pts: 138 | if np.isreal(crit_pt) and crit_pt >= 
x_min_bound and crit_pt <= x_max_bound: 139 | F_cp = np.polyval(coeff, crit_pt) 140 | if np.isreal(F_cp) and F_cp < f_min: 141 | x_sol = np.real(crit_pt) 142 | f_min = np.real(F_cp) 143 | 144 | if plot: 145 | plt.figure() 146 | x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000) 147 | f = np.polyval(coeff, x) 148 | plt.plot(x, f) 149 | plt.plot(x_sol, f_min, 'x') 150 | 151 | return x_sol 152 | 153 | 154 | class LBFGS(Optimizer): 155 | """ 156 | Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap 157 | L-BFGS implementations and (stochastic) Powell damping. Partly based on the 158 | original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code, 159 | and Michael Overton's weak Wolfe line search MATLAB code. 160 | 161 | Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere 162 | Last edited 10/20/20. 163 | 164 | Warnings: 165 | . Does not support per-parameter options and parameter groups. 166 | . All parameters have to be on a single device. 167 | 168 | Args: 169 | lr (float): steplength or learning rate (default: 1) 170 | history_size (int): update history size (default: 10) 171 | line_search (str): designates line search to use (default: 'Wolfe') 172 | Options: 173 | 'None': uses steplength designated in algorithm 174 | 'Armijo': uses Armijo backtracking line search 175 | 'Wolfe': uses Armijo-Wolfe bracketing line search 176 | dtype: data type (default: torch.float) 177 | debug (bool): debugging mode 178 | 179 | References: 180 | [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS 181 | Method for Machine Learning." Advances in Neural Information Processing 182 | Systems. 2016. 183 | [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine 184 | Learning." International Conference on Machine Learning. 2018. 185 | [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton 186 | Methods." Mathematical Programming 141.1-2 (2013): 135-163. 187 | [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for 188 | Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528. 189 | [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage." 190 | Mathematics of Computation 35.151 (1980): 773-782. 191 | [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York, 192 | 2006. 193 | [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization 194 | in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html 195 | (2005). 196 | [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton 197 | Method for Online Convex Optimization." Artificial Intelligence and Statistics. 198 | 2007. 199 | [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic 200 | Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956. 
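# --- Editor's illustration (not part of the module): a worked instance of the
# quadratic-interpolation formula quoted in polyinterp's comments above. For a
# toy line-search function phi(t) with known value and slope at t = 0 and a
# value at a trial step t2, the minimiser of the interpolating quadratic is
# t = -g1 * t2**2 / (2 * (f2 - f1 - g1 * t2)). polyinterp additionally clips
# the result to the bracketing interval before returning it.
def phi(t):
    return (t - 2.0) ** 2               # toy function, true minimiser at t = 2

f1, g1 = phi(0.0), -4.0                 # value and slope at t = 0
t2 = 1.0
f2 = phi(t2)                            # value at the trial steplength

t_min = -g1 * t2 ** 2 / (2 * (f2 - f1 - g1 * t2))
print(t_min)                            # 2.0 -- exact here, since phi is itself quadratic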
201 | 202 | """ 203 | 204 | def __init__(self, params, lr=1., history_size=10, line_search='Wolfe', 205 | dtype=torch.float, debug=False): 206 | 207 | # ensure inputs are valid 208 | if not 0.0 <= lr: 209 | raise ValueError("Invalid learning rate: {}".format(lr)) 210 | if not 0 <= history_size: 211 | raise ValueError("Invalid history size: {}".format(history_size)) 212 | if line_search not in ['Armijo', 'Wolfe', 'None']: 213 | raise ValueError("Invalid line search: {}".format(line_search)) 214 | 215 | defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug) 216 | super(LBFGS, self).__init__(params, defaults) 217 | 218 | if len(self.param_groups) != 1: 219 | raise ValueError("L-BFGS doesn't support per-parameter options " 220 | "(parameter groups)") 221 | 222 | self._params = self.param_groups[0]['params'] 223 | self._numel_cache = None 224 | 225 | state = self.state['global_state'] 226 | state.setdefault('n_iter', 0) 227 | state.setdefault('curv_skips', 0) 228 | state.setdefault('fail_skips', 0) 229 | state.setdefault('H_diag', 1) 230 | state.setdefault('fail', True) 231 | 232 | state['old_dirs'] = [] 233 | state['old_stps'] = [] 234 | 235 | def _numel(self): 236 | if self._numel_cache is None: 237 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 238 | return self._numel_cache 239 | 240 | def _gather_flat_grad(self): 241 | views = [] 242 | for p in self._params: 243 | if p.grad is None: 244 | view = p.data.new(p.data.numel()).zero_() 245 | elif p.grad.data.is_sparse: 246 | view = p.grad.data.to_dense().view(-1) 247 | else: 248 | view = p.grad.data.view(-1) 249 | views.append(view) 250 | return torch.cat(views, 0) 251 | 252 | def _add_update(self, step_size, update): 253 | offset = 0 254 | for p in self._params: 255 | numel = p.numel() 256 | # view as to avoid deprecated pointwise semantics 257 | p.data.add_(step_size, update[offset:offset + numel].view_as(p.data)) 258 | offset += numel 259 | assert offset == self._numel() 260 | 261 | def _copy_params(self): 262 | current_params = [] 263 | for param in self._params: 264 | current_params.append(deepcopy(param.data)) 265 | return current_params 266 | 267 | def _load_params(self, current_params): 268 | i = 0 269 | for param in self._params: 270 | param.data[:] = current_params[i] 271 | i += 1 272 | 273 | def line_search(self, line_search): 274 | """ 275 | Switches line search option. 276 | 277 | Args: 278 | line_search (str): designates line search to use 279 | Options: 280 | 'None': uses steplength designated in algorithm 281 | 'Armijo': uses Armijo backtracking line search 282 | 'Wolfe': uses Armijo-Wolfe bracketing line search 283 | 284 | """ 285 | 286 | group = self.param_groups[0] 287 | group['line_search'] = line_search 288 | 289 | return 290 | 291 | def two_loop_recursion(self, vec): 292 | """ 293 | Performs two-loop recursion on given vector to obtain Hv. 
294 | 295 | Args: 296 | vec (tensor): 1-D tensor to apply two-loop recursion to 297 | 298 | Returns: 299 | r (tensor): matrix-vector product Hv 300 | 301 | """ 302 | 303 | group = self.param_groups[0] 304 | history_size = group['history_size'] 305 | 306 | state = self.state['global_state'] 307 | old_dirs = state.get('old_dirs') # change in gradients 308 | old_stps = state.get('old_stps') # change in iterates 309 | H_diag = state.get('H_diag') 310 | 311 | # compute the product of the inverse Hessian approximation and the gradient 312 | num_old = len(old_dirs) 313 | 314 | if 'rho' not in state: 315 | state['rho'] = [None] * history_size 316 | state['alpha'] = [None] * history_size 317 | rho = state['rho'] 318 | alpha = state['alpha'] 319 | 320 | for i in range(num_old): 321 | rho[i] = 1. / old_stps[i].dot(old_dirs[i]) 322 | 323 | q = vec 324 | for i in range(num_old - 1, -1, -1): 325 | alpha[i] = old_dirs[i].dot(q) * rho[i] 326 | q.add_(-alpha[i], old_stps[i]) 327 | 328 | # multiply by initial Hessian 329 | # r/d is the final direction 330 | r = torch.mul(q, H_diag) 331 | for i in range(num_old): 332 | beta = old_stps[i].dot(r) * rho[i] 333 | r.add_(alpha[i] - beta, old_dirs[i]) 334 | 335 | return r 336 | 337 | def curvature_update(self, flat_grad, eps=1e-2, damping=False): 338 | """ 339 | Performs curvature update. 340 | 341 | Args: 342 | flat_grad (tensor): 1-D tensor of flattened gradient for computing 343 | gradient difference with previously stored gradient 344 | eps (float): constant for curvature pair rejection or damping (default: 1e-2) 345 | damping (bool): flag for using Powell damping (default: False) 346 | """ 347 | 348 | assert len(self.param_groups) == 1 349 | 350 | # load parameters 351 | if eps <= 0: 352 | raise ValueError('Invalid eps; must be positive.') 353 | 354 | group = self.param_groups[0] 355 | history_size = group['history_size'] 356 | debug = group['debug'] 357 | 358 | # variables cached in state (for tracing) 359 | state = self.state['global_state'] 360 | fail = state.get('fail') 361 | 362 | # check if line search failed 363 | if not fail: 364 | 365 | d = state.get('d') 366 | t = state.get('t') 367 | old_dirs = state.get('old_dirs') 368 | old_stps = state.get('old_stps') 369 | H_diag = state.get('H_diag') 370 | prev_flat_grad = state.get('prev_flat_grad') 371 | Bs = state.get('Bs') 372 | 373 | # compute y's 374 | y = flat_grad.sub(prev_flat_grad) 375 | s = d.mul(t) 376 | sBs = s.dot(Bs) 377 | ys = y.dot(s) # y*s 378 | 379 | # update L-BFGS matrix 380 | if ys > eps * sBs or damping is True: 381 | 382 | # perform Powell damping 383 | if damping is True and ys < eps * sBs: 384 | if debug: 385 | print('Applying Powell damping...') 386 | theta = ((1 - eps) * sBs) / (sBs - ys) 387 | y = theta * y + (1 - theta) * Bs 388 | 389 | # updating memory 390 | if len(old_dirs) == history_size: 391 | # shift history by one (limited-memory) 392 | old_dirs.pop(0) 393 | old_stps.pop(0) 394 | 395 | # store new direction/step 396 | old_dirs.append(s) 397 | old_stps.append(y) 398 | 399 | # update scale of initial Hessian approximation 400 | H_diag = ys / y.dot(y) # (y*y) 401 | 402 | state['old_dirs'] = old_dirs 403 | state['old_stps'] = old_stps 404 | state['H_diag'] = H_diag 405 | 406 | else: 407 | # save skip 408 | state['curv_skips'] += 1 409 | if debug: 410 | print('Curvature pair skipped due to failed criterion') 411 | 412 | else: 413 | # save skip 414 | state['fail_skips'] += 1 415 | if debug: 416 | print('Line search failed; curvature pair update skipped') 417 | 418 | return 419 | 
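# --- Editor's illustration (not part of the module): two_loop_recursion above
# applies the standard L-BFGS two-loop recursion to the stored curvature pairs.
# The standalone sketch below performs the same recursion on plain tensors; with
# no pairs stored it reduces to (scaled) steepest descent.
import torch

def lbfgs_direction(grad, old_s, old_y, H_diag=1.0):
    """Return -H @ grad, with H the inverse-Hessian approximation implied by
    the pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i."""
    q = grad.clone()
    rhos, alphas = [], []
    for s, y in zip(reversed(old_s), reversed(old_y)):        # newest pair first
        rho = 1.0 / y.dot(s)
        alpha = rho * s.dot(q)
        q = q - alpha * y
        rhos.append(rho)
        alphas.append(alpha)
    r = H_diag * q                                            # initial Hessian scaling
    for (s, y), rho, alpha in zip(zip(old_s, old_y),          # oldest pair first
                                  reversed(rhos), reversed(alphas)):
        beta = rho * y.dot(r)
        r = r + (alpha - beta) * s
    return -r

# sanity check on f(x) = 0.5 * 4 * x**2: one exact curvature pair makes the
# returned direction the Newton step.
g = torch.tensor([8.0])            # gradient at x = 2
s_hist = [torch.tensor([1.0])]     # a past step of +1
y_hist = [torch.tensor([4.0])]     # the corresponding gradient change
print(lbfgs_direction(g, s_hist, y_hist))   # tensor([-2.]) -> x + d = 0, the minimiser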
420 | def _step(self, p_k, g_Ok, g_Sk=None, options=None): 421 | """ 422 | Performs a single optimization step. 423 | 424 | Args: 425 | p_k (tensor): 1-D tensor specifying search direction 426 | g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used 427 | for gradient differencing in curvature pair update 428 | g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k 429 | used for curvature pair damping or rejection criterion, 430 | if None, will use g_Ok (default: None) 431 | options (dict): contains options for performing line search (default: None) 432 | 433 | Options for Armijo backtracking line search: 434 | 'closure' (callable): reevaluates model and returns function value 435 | 'current_loss' (tensor): objective value at current iterate (default: F(x_k)) 436 | 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd) 437 | 'eta' (tensor): factor for decreasing steplength > 0 (default: 2) 438 | 'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4) 439 | 'max_ls' (int): maximum number of line search steps permitted (default: 10) 440 | 'interpolate' (bool): flag for using interpolation (default: True) 441 | 'inplace' (bool): flag for inplace operations (default: True) 442 | 'ls_debug' (bool): debugging mode for line search 443 | 444 | Options for Wolfe line search: 445 | 'closure' (callable): reevaluates model and returns function value 446 | 'current_loss' (tensor): objective value at current iterate (default: F(x_k)) 447 | 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd) 448 | 'eta' (float): factor for extrapolation (default: 2) 449 | 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4) 450 | 'c2' (float): curvature condition constant in (0, 1) (default: 0.9) 451 | 'max_ls' (int): maximum number of line search steps permitted (default: 10) 452 | 'interpolate' (bool): flag for using interpolation (default: True) 453 | 'inplace' (bool): flag for inplace operations (default: True) 454 | 'ls_debug' (bool): debugging mode for line search 455 | 456 | Returns (depends on line search): 457 | . No line search: 458 | t (float): steplength 459 | . Armijo backtracking line search: 460 | F_new (tensor): loss function at new iterate 461 | t (tensor): final steplength 462 | ls_step (int): number of backtracks 463 | closure_eval (int): number of closure evaluations 464 | desc_dir (bool): descent direction flag 465 | True: p_k is descent direction with respect to the line search 466 | function 467 | False: p_k is not a descent direction with respect to the line 468 | search function 469 | fail (bool): failure flag 470 | True: line search reached maximum number of iterations, failed 471 | False: line search succeeded 472 | . Wolfe line search: 473 | F_new (tensor): loss function at new iterate 474 | g_new (tensor): gradient at new iterate 475 | t (float): final steplength 476 | ls_step (int): number of backtracks 477 | closure_eval (int): number of closure evaluations 478 | grad_eval (int): number of gradient evaluations 479 | desc_dir (bool): descent direction flag 480 | True: p_k is descent direction with respect to the line search 481 | function 482 | False: p_k is not a descent direction with respect to the line 483 | search function 484 | fail (bool): failure flag 485 | True: line search reached maximum number of iterations, failed 486 | False: line search succeeded 487 | 488 | Notes: 489 | . 
If encountering line search failure in the deterministic setting, one 490 | should try increasing the maximum number of line search steps max_ls. 491 | 492 | """ 493 | 494 | if options is None: 495 | options = {} 496 | assert len(self.param_groups) == 1 497 | 498 | # load parameter options 499 | group = self.param_groups[0] 500 | lr = group['lr'] 501 | line_search = group['line_search'] 502 | dtype = group['dtype'] 503 | debug = group['debug'] 504 | 505 | # variables cached in state (for tracing) 506 | state = self.state['global_state'] 507 | d = state.get('d') 508 | t = state.get('t') 509 | prev_flat_grad = state.get('prev_flat_grad') 510 | Bs = state.get('Bs') 511 | 512 | # keep track of nb of iterations 513 | state['n_iter'] += 1 514 | 515 | # set search direction 516 | d = p_k 517 | 518 | # modify previous gradient 519 | if prev_flat_grad is None: 520 | prev_flat_grad = g_Ok.clone() 521 | else: 522 | prev_flat_grad.copy_(g_Ok) 523 | 524 | # set initial step size 525 | t = lr 526 | 527 | # closure evaluation counter 528 | closure_eval = 0 529 | 530 | if g_Sk is None: 531 | g_Sk = g_Ok.clone() 532 | 533 | # perform Armijo backtracking line search 534 | if line_search == 'Armijo': 535 | 536 | # load options 537 | if options: 538 | if 'closure' not in options.keys(): 539 | raise ValueError('closure option not specified.') 540 | else: 541 | closure = options['closure'] 542 | 543 | if 'gtd' not in options.keys(): 544 | gtd = g_Sk.dot(d) 545 | else: 546 | gtd = options['gtd'] 547 | 548 | if 'current_loss' not in options.keys(): 549 | F_k = closure() 550 | closure_eval += 1 551 | else: 552 | F_k = options['current_loss'] 553 | 554 | if 'eta' not in options.keys(): 555 | eta = 2 556 | elif options['eta'] <= 0: 557 | raise ValueError('Invalid eta; must be positive.') 558 | else: 559 | eta = options['eta'] 560 | 561 | if 'c1' not in options.keys(): 562 | c1 = 1e-4 563 | elif options['c1'] >= 1 or options['c1'] <= 0: 564 | raise ValueError('Invalid c1; must be strictly between 0 and 1.') 565 | else: 566 | c1 = options['c1'] 567 | 568 | if 'max_ls' not in options.keys(): 569 | max_ls = 10 570 | elif options['max_ls'] <= 0: 571 | raise ValueError('Invalid max_ls; must be positive.') 572 | else: 573 | max_ls = options['max_ls'] 574 | 575 | if 'interpolate' not in options.keys(): 576 | interpolate = True 577 | else: 578 | interpolate = options['interpolate'] 579 | 580 | if 'inplace' not in options.keys(): 581 | inplace = True 582 | else: 583 | inplace = options['inplace'] 584 | 585 | if 'ls_debug' not in options.keys(): 586 | ls_debug = False 587 | else: 588 | ls_debug = options['ls_debug'] 589 | 590 | else: 591 | raise ValueError('Options are not specified; need closure evaluating function.') 592 | 593 | # initialize values 594 | if interpolate: 595 | if torch.cuda.is_available(): 596 | F_prev = torch.tensor(np.nan, dtype=dtype).cuda() 597 | else: 598 | F_prev = torch.tensor(np.nan, dtype=dtype) 599 | 600 | ls_step = 0 601 | t_prev = 0 # old steplength 602 | fail = False # failure flag 603 | 604 | # begin print for debug mode 605 | if ls_debug: 606 | print('==================================== Begin Armijo line search ===================================') 607 | print('F(x): %.8e g*d: %.8e' % (F_k, gtd)) 608 | 609 | # check if search direction is descent direction 610 | if gtd >= 0: 611 | desc_dir = False 612 | if debug: 613 | print('Not a descent direction!') 614 | else: 615 | desc_dir = True 616 | 617 | # store values if not in-place 618 | if not inplace: 619 | current_params = self._copy_params() 
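# (Editor's note) The backtracking loop below enforces the Armijo sufficient-
# decrease condition: a trial steplength t is accepted only once
#     F(x + t*d) <= F(x) + c1 * t * (g' d)
# holds and F(x + t*d) is finite; otherwise t is shrunk, either by the fixed
# factor eta or by polynomial interpolation through the points collected so far.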
620 | 621 | # update and evaluate at new point 622 | self._add_update(t, d) 623 | F_new = closure() 624 | closure_eval += 1 625 | 626 | # print info if debugging 627 | if ls_debug: 628 | print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e' 629 | % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)) 630 | 631 | # check Armijo condition 632 | while F_new > F_k + c1 * t * gtd or not is_legal(F_new): 633 | 634 | # check if maximum number of iterations reached 635 | if ls_step >= max_ls: 636 | if inplace: 637 | self._add_update(-t, d) 638 | else: 639 | self._load_params(current_params) 640 | 641 | t = 0 642 | F_new = closure() 643 | closure_eval += 1 644 | fail = True 645 | break 646 | 647 | else: 648 | # store current steplength 649 | t_new = t 650 | 651 | # compute new steplength 652 | 653 | # if first step or not interpolating, then multiply by factor 654 | if ls_step == 0 or not interpolate or not is_legal(F_new): 655 | t = t / eta 656 | 657 | # if second step, use function value at new point along with 658 | # gradient and function at current iterate 659 | elif ls_step == 1 or not is_legal(F_prev): 660 | t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]])) 661 | 662 | # otherwise, use function values at new point, previous point, 663 | # and gradient and function at current iterate 664 | else: 665 | t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan], 666 | [t_prev, F_prev.item(), np.nan]])) 667 | 668 | # if values are too extreme, adjust t 669 | if interpolate: 670 | if t < 1e-3 * t_new: 671 | t = 1e-3 * t_new 672 | elif t > 0.6 * t_new: 673 | t = 0.6 * t_new 674 | 675 | # store old point 676 | F_prev = F_new 677 | t_prev = t_new 678 | 679 | # update iterate and reevaluate 680 | if inplace: 681 | self._add_update(t - t_new, d) 682 | else: 683 | self._load_params(current_params) 684 | self._add_update(t, d) 685 | 686 | F_new = closure() 687 | closure_eval += 1 688 | ls_step += 1 # iterate 689 | 690 | # print info if debugging 691 | if ls_debug: 692 | print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e' 693 | % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)) 694 | 695 | # store Bs 696 | if Bs is None: 697 | Bs = (g_Sk.mul(-t)).clone() 698 | else: 699 | Bs.copy_(g_Sk.mul(-t)) 700 | 701 | # print final steplength 702 | if ls_debug: 703 | print('Final Steplength:', t) 704 | print('===================================== End Armijo line search ====================================') 705 | 706 | state['d'] = d 707 | state['prev_flat_grad'] = prev_flat_grad 708 | state['t'] = t 709 | state['Bs'] = Bs 710 | state['fail'] = fail 711 | 712 | return F_new, t, ls_step, closure_eval, desc_dir, fail 713 | 714 | # perform weak Wolfe line search 715 | elif line_search == 'Wolfe': 716 | 717 | # load options 718 | if options: 719 | if 'closure' not in options.keys(): 720 | raise ValueError('closure option not specified.') 721 | else: 722 | closure = options['closure'] 723 | 724 | if 'current_loss' not in options.keys(): 725 | F_k = closure() 726 | closure_eval += 1 727 | else: 728 | F_k = options['current_loss'] 729 | 730 | if 'gtd' not in options.keys(): 731 | gtd = g_Sk.dot(d) 732 | else: 733 | gtd = options['gtd'] 734 | 735 | if 'eta' not in options.keys(): 736 | eta = 2 737 | elif options['eta'] <= 1: 738 | raise ValueError('Invalid eta; must be greater than 1.') 739 | else: 740 | eta = options['eta'] 741 | 742 | if 'c1' not in options.keys(): 743 | c1 = 1e-4 744 | elif options['c1'] >= 1 or options['c1'] <= 0: 745 
| raise ValueError('Invalid c1; must be strictly between 0 and 1.') 746 | else: 747 | c1 = options['c1'] 748 | 749 | if 'c2' not in options.keys(): 750 | c2 = 0.9 751 | elif options['c2'] >= 1 or options['c2'] <= 0: 752 | raise ValueError('Invalid c2; must be strictly between 0 and 1.') 753 | elif options['c2'] <= c1: 754 | raise ValueError('Invalid c2; must be strictly larger than c1.') 755 | else: 756 | c2 = options['c2'] 757 | 758 | if 'max_ls' not in options.keys(): 759 | max_ls = 10 760 | elif options['max_ls'] <= 0: 761 | raise ValueError('Invalid max_ls; must be positive.') 762 | else: 763 | max_ls = options['max_ls'] 764 | 765 | if 'interpolate' not in options.keys(): 766 | interpolate = True 767 | else: 768 | interpolate = options['interpolate'] 769 | 770 | if 'inplace' not in options.keys(): 771 | inplace = True 772 | else: 773 | inplace = options['inplace'] 774 | 775 | if 'ls_debug' not in options.keys(): 776 | ls_debug = False 777 | else: 778 | ls_debug = options['ls_debug'] 779 | 780 | else: 781 | raise ValueError('Options are not specified; need closure evaluating function.') 782 | 783 | # initialize counters 784 | ls_step = 0 785 | grad_eval = 0 # tracks gradient evaluations 786 | t_prev = 0 # old steplength 787 | 788 | # initialize bracketing variables and flag 789 | alpha = 0 790 | beta = float('Inf') 791 | fail = False 792 | 793 | # initialize values for line search 794 | if interpolate: 795 | F_a = F_k 796 | g_a = gtd 797 | 798 | if torch.cuda.is_available(): 799 | F_b = torch.tensor(np.nan, dtype=dtype).cuda() 800 | g_b = torch.tensor(np.nan, dtype=dtype).cuda() 801 | else: 802 | F_b = torch.tensor(np.nan, dtype=dtype) 803 | g_b = torch.tensor(np.nan, dtype=dtype) 804 | 805 | # begin print for debug mode 806 | if ls_debug: 807 | print('==================================== Begin Wolfe line search ====================================') 808 | print('F(x): %.8e g*d: %.8e' % (F_k, gtd)) 809 | 810 | # check if search direction is descent direction 811 | if gtd >= 0: 812 | desc_dir = False 813 | if debug: 814 | print('Not a descent direction!') 815 | else: 816 | desc_dir = True 817 | 818 | # store values if not in-place 819 | if not inplace: 820 | current_params = self._copy_params() 821 | 822 | # update and evaluate at new point 823 | self._add_update(t, d) 824 | F_new = closure() 825 | closure_eval += 1 826 | 827 | # main loop 828 | while True: 829 | 830 | # check if maximum number of line search steps have been reached 831 | if ls_step >= max_ls: 832 | if inplace: 833 | self._add_update(-t, d) 834 | else: 835 | self._load_params(current_params) 836 | 837 | t = 0 838 | F_new = closure() 839 | F_new.backward() 840 | g_new = self._gather_flat_grad() 841 | closure_eval += 1 842 | grad_eval += 1 843 | fail = True 844 | break 845 | 846 | # print info if debugging 847 | if ls_debug: 848 | print('LS Step: %d t: %.8e alpha: %.8e beta: %.8e' 849 | % (ls_step, t, alpha, beta)) 850 | print('Armijo: F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e' 851 | % (F_new, F_k + c1 * t * gtd, F_k)) 852 | 853 | # check Armijo condition 854 | if F_new > F_k + c1 * t * gtd: 855 | 856 | # set upper bound 857 | beta = t 858 | t_prev = t 859 | 860 | # update interpolation quantities 861 | if interpolate: 862 | F_b = F_new 863 | if torch.cuda.is_available(): 864 | g_b = torch.tensor(np.nan, dtype=dtype).cuda() 865 | else: 866 | g_b = torch.tensor(np.nan, dtype=dtype) 867 | 868 | else: 869 | 870 | # compute gradient 871 | F_new.backward() 872 | g_new = self._gather_flat_grad() 873 | grad_eval += 1 874 | 
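# (Editor's note) The Armijo condition is already satisfied on this branch, so
# the code below checks the weak Wolfe curvature condition
#     g(x + t*d)' d >= c2 * (g' d);
# if it fails, t becomes the new lower bracket alpha and the search continues,
# otherwise the steplength is accepted and the line search terminates.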
gtd_new = g_new.dot(d) 875 | 876 | # print info if debugging 877 | if ls_debug: 878 | print('Wolfe: g(x+td)*d: %.8e c2*g*d: %.8e gtd: %.8e' 879 | % (gtd_new, c2 * gtd, gtd)) 880 | 881 | # check curvature condition 882 | if gtd_new < c2 * gtd: 883 | 884 | # set lower bound 885 | alpha = t 886 | t_prev = t 887 | 888 | # update interpolation quantities 889 | if interpolate: 890 | F_a = F_new 891 | g_a = gtd_new 892 | 893 | else: 894 | break 895 | 896 | # compute new steplength 897 | 898 | # if first step or not interpolating, then bisect or multiply by factor 899 | if not interpolate or not is_legal(F_b): 900 | if beta == float('Inf'): 901 | t = eta * t 902 | else: 903 | t = (alpha + beta) / 2.0 904 | 905 | # otherwise interpolate between a and b 906 | else: 907 | t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]])) 908 | 909 | # if values are too extreme, adjust t 910 | if beta == float('Inf'): 911 | if t > 2 * eta * t_prev: 912 | t = 2 * eta * t_prev 913 | elif t < eta * t_prev: 914 | t = eta * t_prev 915 | else: 916 | if t < alpha + 0.2 * (beta - alpha): 917 | t = alpha + 0.2 * (beta - alpha) 918 | elif t > (beta - alpha) / 2.0: 919 | t = (beta - alpha) / 2.0 920 | 921 | # if we obtain nonsensical value from interpolation 922 | if t <= 0: 923 | t = (beta - alpha) / 2.0 924 | 925 | # update parameters 926 | if inplace: 927 | self._add_update(t - t_prev, d) 928 | else: 929 | self._load_params(current_params) 930 | self._add_update(t, d) 931 | 932 | # evaluate closure 933 | F_new = closure() 934 | closure_eval += 1 935 | ls_step += 1 936 | 937 | # store Bs 938 | if Bs is None: 939 | Bs = (g_Sk.mul(-t)).clone() 940 | else: 941 | Bs.copy_(g_Sk.mul(-t)) 942 | 943 | # print final steplength 944 | if ls_debug: 945 | print('Final Steplength:', t) 946 | print('===================================== End Wolfe line search =====================================') 947 | 948 | state['d'] = d 949 | state['prev_flat_grad'] = prev_flat_grad 950 | state['t'] = t 951 | state['Bs'] = Bs 952 | state['fail'] = fail 953 | 954 | return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail 955 | 956 | else: 957 | 958 | # perform update 959 | self._add_update(t, d) 960 | 961 | # store Bs 962 | if Bs is None: 963 | Bs = (g_Sk.mul(-t)).clone() 964 | else: 965 | Bs.copy_(g_Sk.mul(-t)) 966 | 967 | state['d'] = d 968 | state['prev_flat_grad'] = prev_flat_grad 969 | state['t'] = t 970 | state['Bs'] = Bs 971 | state['fail'] = False 972 | 973 | return t 974 | 975 | def step(self, p_k, g_Ok, g_Sk=None, options={}): 976 | return self._step(p_k, g_Ok, g_Sk, options) 977 | 978 | 979 | class FullBatchLBFGS(LBFGS): 980 | """ 981 | Implements full-batch or deterministic L-BFGS algorithm. Compatible with 982 | Powell damping. Can be used when evaluating a deterministic function and 983 | gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion, 984 | updating, and curvature updating in a single step. 985 | 986 | Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere 987 | Last edited 11/15/18. 988 | 989 | Warnings: 990 | . Does not support per-parameter options and parameter groups. 991 | . All parameters have to be on a single device. 
992 | 993 | Args: 994 | lr (float): steplength or learning rate (default: 1) 995 | history_size (int): update history size (default: 10) 996 | line_search (str): designates line search to use (default: 'Wolfe') 997 | Options: 998 | 'None': uses steplength designated in algorithm 999 | 'Armijo': uses Armijo backtracking line search 1000 | 'Wolfe': uses Armijo-Wolfe bracketing line search 1001 | dtype: data type (default: torch.float) 1002 | debug (bool): debugging mode 1003 | 1004 | """ 1005 | 1006 | def __init__(self, params, lr=1, history_size=10, line_search='Wolfe', 1007 | dtype=torch.float, debug=False): 1008 | super(FullBatchLBFGS, self).__init__( 1009 | params, lr, history_size, line_search, dtype, debug, 1010 | ) 1011 | 1012 | def step(self, options=None): 1013 | """ 1014 | Performs a single optimization step. 1015 | 1016 | Args: 1017 | options (dict): contains options for performing line search (default: None) 1018 | 1019 | General Options: 1020 | 'eps' (float): constant for curvature pair rejection or damping (default: 1e-2) 1021 | 'damping' (bool): flag for using Powell damping (default: False) 1022 | 1023 | Options for Armijo backtracking line search: 1024 | 'closure' (callable): reevaluates model and returns function value 1025 | 'current_loss' (tensor): objective value at current iterate (default: F(x_k)) 1026 | 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd) 1027 | 'eta' (tensor): factor for decreasing steplength > 0 (default: 2) 1028 | 'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4) 1029 | 'max_ls' (int): maximum number of line search steps permitted (default: 10) 1030 | 'interpolate' (bool): flag for using interpolation (default: True) 1031 | 'inplace' (bool): flag for inplace operations (default: True) 1032 | 'ls_debug' (bool): debugging mode for line search 1033 | 1034 | Options for Wolfe line search: 1035 | 'closure' (callable): reevaluates model and returns function value 1036 | 'current_loss' (tensor): objective value at current iterate (default: F(x_k)) 1037 | 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd) 1038 | 'eta' (float): factor for extrapolation (default: 2) 1039 | 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4) 1040 | 'c2' (float): curvature condition constant in (0, 1) (default: 0.9) 1041 | 'max_ls' (int): maximum number of line search steps permitted (default: 10) 1042 | 'interpolate' (bool): flag for using interpolation (default: True) 1043 | 'inplace' (bool): flag for inplace operations (default: True) 1044 | 'ls_debug' (bool): debugging mode for line search 1045 | 1046 | Outputs (depends on line search): 1047 | . No line search: 1048 | t (float): steplength 1049 | . Armijo backtracking line search: 1050 | F_new (tensor): loss function at new iterate 1051 | t (tensor): final steplength 1052 | ls_step (int): number of backtracks 1053 | closure_eval (int): number of closure evaluations 1054 | desc_dir (bool): descent direction flag 1055 | True: p_k is descent direction with respect to the line search 1056 | function 1057 | False: p_k is not a descent direction with respect to the line 1058 | search function 1059 | fail (bool): failure flag 1060 | True: line search reached maximum number of iterations, failed 1061 | False: line search succeeded 1062 | . 
Wolfe line search: 1063 | F_new (tensor): loss function at new iterate 1064 | g_new (tensor): gradient at new iterate 1065 | t (float): final steplength 1066 | ls_step (int): number of backtracks 1067 | closure_eval (int): number of closure evaluations 1068 | grad_eval (int): number of gradient evaluations 1069 | desc_dir (bool): descent direction flag 1070 | True: p_k is descent direction with respect to the line search 1071 | function 1072 | False: p_k is not a descent direction with respect to the line 1073 | search function 1074 | fail (bool): failure flag 1075 | True: line search reached maximum number of iterations, failed 1076 | False: line search succeeded 1077 | 1078 | Notes: 1079 | . If encountering line search failure in the deterministic setting, one 1080 | should try increasing the maximum number of line search steps max_ls. 1081 | 1082 | """ 1083 | 1084 | # load options for damping and eps 1085 | if 'damping' not in options.keys(): 1086 | damping = False 1087 | else: 1088 | damping = options['damping'] 1089 | 1090 | if 'eps' not in options.keys(): 1091 | eps = 1e-2 1092 | else: 1093 | eps = options['eps'] 1094 | 1095 | # gather gradient 1096 | grad = self._gather_flat_grad() 1097 | 1098 | # update curvature if after 1st iteration 1099 | state = self.state['global_state'] 1100 | if state['n_iter'] > 0: 1101 | self.curvature_update(grad, eps, damping) 1102 | 1103 | # compute search direction 1104 | p = self.two_loop_recursion(-grad) 1105 | 1106 | # take step 1107 | return self._step(p, grad, options=options) 1108 | --------------------------------------------------------------------------------
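The snippet below is an editor's sketch of how FullBatchLBFGS defined above is typically driven, fitting a toy least-squares problem with the closure/backward/step pattern documented in the class. The import path, model, and data are assumptions made for illustration only, and the legacy two-argument Tensor.add_ overload used in _add_update requires a PyTorch version that still accepts it.

import torch
from BASQ._lbfgs import FullBatchLBFGS   # assumed import path

# toy least-squares problem: recover w_true from X @ w_true
torch.manual_seed(0)
X = torch.randn(50, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])
y = X @ w_true
w = torch.zeros(3, requires_grad=True)

optimizer = FullBatchLBFGS([w], lr=1.0, history_size=10, line_search='Wolfe')

def closure():
    optimizer.zero_grad()
    return ((X @ w - y) ** 2).mean()

loss = closure()
loss.backward()                           # gradients must exist before the first step
for _ in range(20):
    options = {'closure': closure, 'current_loss': loss, 'max_ls': 20}
    loss, grad, t, ls_step, _, _, _, fail = optimizer.step(options)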