├── .travis.yml ├── LICENSE ├── README.md ├── codecov.yml ├── examples └── example_lc_extrapolation.py ├── notebooks ├── bohamiann_example.ipynb ├── dngo_example.ipynb └── sampler.ipynb ├── pybnn ├── __init__.py ├── base_model.py ├── bayesian_linear_regression.py ├── bohamiann.py ├── dngo.py ├── lc_extrapolation │ ├── __init__.py │ ├── curvefunctions.py │ ├── curvemodels.py │ └── learning_curves.py ├── lcnet.py ├── multi_task_bohamiann.py ├── priors.py ├── sampler │ ├── __init__.py │ ├── adaptive_sghmc.py │ ├── preconditioned_sgld.py │ ├── sghmc.py │ └── sgld.py └── util │ ├── __init__.py │ ├── infinite_dataloader.py │ ├── layers.py │ └── normalization.py ├── requirements.txt ├── setup.py └── test ├── __init__.py ├── test_bohamiann.py ├── test_dngo.py ├── test_lcnet.py ├── test_mtbohamiann.py └── test_normalization.py /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | os: 3 | - linux 4 | python: 5 | - "3.5" 6 | install: 7 | - pip install -r requirements.txt 8 | - python setup.py install 9 | - pip install codecov 10 | - pip install pytest pytest-cov 11 | 12 | script: 13 | - pytest --cov=./ 14 | after_success: 15 | - codecov 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, ML4AAD 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/automl/pybnn.svg?branch=master)](https://travis-ci.org/automl/pybnn) 2 | [![codecov](https://codecov.io/gh/automl/pybnn/branch/master/graph/badge.svg)](https://codecov.io/gh/automl/pybnn) 3 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://github.com/automl/pybnn/blob/master/LICENSE) 4 | 5 | # pybnn 6 | Bayesian neural networks for Bayesian optimization. 7 | 8 | It contains implementations for methods described in the following papers: 9 | - [Scalable Bayesian Optimization Using Deep Neural Networks](https://arxiv.org/pdf/1502.05700.pdf) (DNGO) 10 | - [Bayesian Optimization With Robust Bayesian Neural Networks](https://ml.informatik.uni-freiburg.de/papers/16-NIPS-BOHamiANN.pdf) (BOHAMIANN) 11 | - [Learning Curve Prediction With Bayesian Neural Networks](http://ml.informatik.uni-freiburg.de/papers/17-ICLR-LCNet.pdf) (LC-Net) 12 | 13 | # Installation 14 | 15 | git clone https://github.com/automl/pybnn.git 16 | cd pybnn 17 | python setup.py install 18 | 19 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | require_ci_to_pass: yes 4 | 5 | coverage: 6 | precision: 2 7 | round: down 8 | range: "70...100" 9 | 10 | status: 11 | project: yes 12 | patch: yes 13 | changes: no 14 | 15 | parsers: 16 | gcov: 17 | branch_detection: 18 | conditional: yes 19 | loop: yes 20 | method: no 21 | macro: no 22 | 23 | comment: 24 | layout: "header, diff" 25 | behavior: default 26 | require_changes: no 27 | 28 | ignore: 29 | - "pybnn/notebooks/.*" 30 | - "pybnn/setup.py" 31 | - "pybnn/test/*" 32 | -------------------------------------------------------------------------------- /examples/example_lc_extrapolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import matplotlib.pyplot as plt 4 | from pybnn.lc_extrapolation.learning_curves import MCMCCurveModelCombination 5 | 6 | 7 | observed = 40 8 | n_epochs = 100 9 | 10 | t_idx = np.arange(1, observed+1) 11 | t_idx_full = np.arange(1, n_epochs+1) 12 | 13 | 14 | def toy_example(t, a, b): 15 | return (10 + a * np.log(b * t + 1e-8)) / 10. 
# + 10e-3 * np.random.rand() 16 | 17 | 18 | a = np.random.rand() 19 | b = np.random.rand() 20 | lc = [toy_example(t / n_epochs, a, b) for t in t_idx_full] 21 | 22 | model = MCMCCurveModelCombination(n_epochs + 1, 23 | nwalkers=50, 24 | nsamples=800, 25 | burn_in=500, 26 | recency_weighting=False, 27 | soft_monotonicity_constraint=False, 28 | monotonicity_constraint=True, 29 | initial_model_weight_ml_estimate=True) 30 | st = time.time() 31 | model.fit(t_idx, lc[:observed]) 32 | print("Training time: %.2f" % (time.time() - st)) 33 | 34 | 35 | st = time.time() 36 | p_greater = model.posterior_prob_x_greater_than(n_epochs + 1, .5) 37 | print("Prediction time: %.2f" % (time.time() - st)) 38 | 39 | m = np.zeros([n_epochs]) 40 | s = np.zeros([n_epochs]) 41 | 42 | for i in range(n_epochs): 43 | p = model.predictive_distribution(i+1) 44 | m[i] = np.mean(p) 45 | s[i] = np.std(p) 46 | 47 | mean_mcmc = m[-1] 48 | std_mcmc = s[-1] 49 | 50 | plt.plot(t_idx_full, m, color="purple", label="LC-Extrapolation") 51 | plt.fill_between(t_idx_full, m + s, m - s, alpha=0.2, color="purple") 52 | plt.plot(t_idx_full, lc) 53 | 54 | plt.xlim(1, n_epochs) 55 | plt.legend() 56 | plt.xlabel("Number of epochs") 57 | plt.ylabel("Validation error") 58 | plt.axvline(observed, linestyle="--", color="black") 59 | plt.show() -------------------------------------------------------------------------------- /pybnn/__init__.py: -------------------------------------------------------------------------------- 1 | from pybnn.dngo import DNGO 2 | from pybnn.bayesian_linear_regression import BayesianLinearRegression 3 | from pybnn.base_model import BaseModel 4 | -------------------------------------------------------------------------------- /pybnn/base_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | 5 | class BaseModel(object): 6 | __metaclass__ = abc.ABCMeta 7 | 8 | def __init__(self): 9 | """ 10 | Abstract base class for all models 11 | """ 12 | self.X = None 13 | self.y = None 14 | 15 | @abc.abstractmethod 16 | def train(self, X, y): 17 | """ 18 | Trains the model on the provided data. 19 | 20 | Parameters 21 | ---------- 22 | X: np.ndarray (N, D) 23 | Input data points. The dimensionality of X is (N, D), 24 | with N as the number of points and D is the number of input dimensions. 25 | y: np.ndarray (N,) 26 | The corresponding target values of the input data points. 27 | """ 28 | pass 29 | 30 | def update(self, X, y): 31 | """ 32 | Update the model with the new additional data. Override this function if your 33 | model allows to do something smarter than simple retraining 34 | 35 | Parameters 36 | ---------- 37 | X: np.ndarray (N, D) 38 | Input data points. The dimensionality of X is (N, D), 39 | with N as the number of points and D is the number of input dimensions. 40 | y: np.ndarray (N,) 41 | The corresponding target values of the input data points. 
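        Note: the default implementation below simply appends the new (X, y)
        pairs to the stored data and retrains on the concatenated arrays.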
42 | """ 43 | X = np.append(self.X, X, axis=0) 44 | y = np.append(self.y, y, axis=0) 45 | self.train(X, y) 46 | 47 | @abc.abstractmethod 48 | def predict(self, X_test): 49 | """ 50 | Predicts for a given set of test data points the mean and variance of its target values 51 | 52 | Parameters 53 | ---------- 54 | X_test: np.ndarray (N, D) 55 | N Test data points with input dimensions D 56 | 57 | Returns 58 | ---------- 59 | mean: ndarray (N,) 60 | Predictive mean of the test data points 61 | var: ndarray (N,) 62 | Predictive variance of the test data points 63 | """ 64 | pass 65 | 66 | def _check_shapes_train(func): 67 | def func_wrapper(self, X, y, *args, **kwargs): 68 | assert X.shape[0] == y.shape[0] 69 | assert len(X.shape) == 2 70 | assert len(y.shape) == 1 71 | return func(self, X, y, *args, **kwargs) 72 | return func_wrapper 73 | 74 | def _check_shapes_predict(func): 75 | def func_wrapper(self, X, *args, **kwargs): 76 | assert len(X.shape) == 2 77 | return func(self, X, *args, **kwargs) 78 | 79 | return func_wrapper 80 | 81 | def get_json_data(self): 82 | """ 83 | Json getter function' 84 | 85 | Returns 86 | ---------- 87 | dictionary 88 | """ 89 | json_data = {'X': self.X if self.X is None else self.X.tolist(), 90 | 'y': self.y if self.y is None else self.y.tolist(), 91 | 'hyperparameters': ""} 92 | return json_data 93 | 94 | def get_incumbent(self): 95 | """ 96 | Returns the best observed point and its function value 97 | 98 | Returns 99 | ---------- 100 | incumbent: ndarray (D,) 101 | current incumbent 102 | incumbent_value: ndarray (N,) 103 | the observed value of the incumbent 104 | """ 105 | best_idx = np.argmin(self.y) 106 | return self.X[best_idx], self.y[best_idx] 107 | -------------------------------------------------------------------------------- /pybnn/bayesian_linear_regression.py: -------------------------------------------------------------------------------- 1 | import emcee 2 | import logging 3 | import numpy as np 4 | 5 | from scipy import optimize 6 | from scipy import stats 7 | 8 | from pybnn.base_model import BaseModel 9 | 10 | 11 | def linear_basis_func(x): 12 | return np.append(x, np.ones([x.shape[0], 1]), axis=1) 13 | 14 | 15 | def quadratic_basis_func(x): 16 | x = np.append(x ** 2, x, axis=1) 17 | return np.append(x, np.ones([x.shape[0], 1]), axis=1) 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class Prior(object): 24 | 25 | def __init__(self, rng=None): 26 | if rng is None: 27 | self.rng = np.random.RandomState(np.random.randint(0, 10000)) 28 | else: 29 | self.rng = rng 30 | 31 | def lnprob(self, theta): 32 | """ 33 | Compute the log probability for theta = [log alpha, log beta] 34 | :param theta: 35 | :return: log p(theta) 36 | """ 37 | lp = 0 38 | lp += stats.norm.pdf(theta[0], loc=0, scale=1) # log alpha 39 | lp += stats.norm.pdf(theta[1], loc=0, scale=1) # log sigma^2 40 | 41 | return lp 42 | 43 | def sample_from_prior(self, n_samples): 44 | p0 = np.zeros([n_samples, 2]) 45 | 46 | # Log alpha 47 | p0[:, 0] = self.rng.normal(loc=0, 48 | scale=1, 49 | size=n_samples) 50 | 51 | # Log sigma^2 52 | p0[:, 1] = self.rng.normal(loc=-3, 53 | scale=1, 54 | size=n_samples) 55 | return p0 56 | 57 | 58 | class BayesianLinearRegression(BaseModel): 59 | 60 | def __init__(self, alpha=1, beta=1000, basis_func=linear_basis_func, 61 | prior=None, do_mcmc=True, n_hypers=20, chain_length=2000, 62 | burnin_steps=2000, rng=None): 63 | """ 64 | Implementation of Bayesian linear regression. 
See chapter 3.3 of the book 65 | "Pattern Recognition and Machine Learning" by Bishop for more details. 66 | 67 | Parameters 68 | ---------- 69 | alpha: float 70 | Specifies the variance of the prior for the weights w 71 | beta : float 72 | Defines the inverse of the noise, i.e. beta = 1 / sigma^2 73 | basis_func : function 74 | Function handle to transfer the input with via basis functions 75 | (see the code above for an example) 76 | prior: Prior object 77 | Prior for alpha and beta. If set to None the default prior is used 78 | do_mcmc: bool 79 | If set to true different values for alpha and beta are sampled via MCMC from the marginal log likelihood 80 | Otherwise the marginal log likelihood is optimized with scipy fmin function 81 | n_hypers : int 82 | Number of samples for alpha and beta 83 | chain_length : int 84 | The chain length of the MCMC sampler 85 | burnin_steps: int 86 | The number of burnin steps before the sampling procedure starts 87 | rng: np.random.RandomState 88 | Random number generator 89 | """ 90 | 91 | if rng is None: 92 | self.rng = np.random.RandomState(np.random.randint(0, 10000)) 93 | else: 94 | self.rng = rng 95 | 96 | self.X = None 97 | self.y = None 98 | self.alpha = alpha 99 | self.beta = beta 100 | self.basis_func = basis_func 101 | if prior is None: 102 | self.prior = Prior(rng=self.rng) 103 | else: 104 | self.prior = prior 105 | self.do_mcmc = do_mcmc 106 | self.n_hypers = n_hypers 107 | self.chain_length = chain_length 108 | self.burned = False 109 | self.burnin_steps = burnin_steps 110 | self.models = None 111 | 112 | def marginal_log_likelihood(self, theta): 113 | """ 114 | Log likelihood of the data marginalised over the weights w. See chapter 3.5 of 115 | the book by Bishop of an derivation. 116 | 117 | Parameters 118 | ---------- 119 | theta: np.array(2,) 120 | The hyperparameter alpha and beta on a log scale 121 | 122 | Returns 123 | ------- 124 | float 125 | lnlikelihood + prior 126 | """ 127 | 128 | # Theta is on a log scale 129 | alpha = np.exp(theta[0]) 130 | beta = 1 / np.exp(theta[1]) 131 | 132 | D = self.X_transformed.shape[1] 133 | N = self.X_transformed.shape[0] 134 | 135 | A = beta * np.dot(self.X_transformed.T, self.X_transformed) 136 | A += np.eye(self.X_transformed.shape[1]) * alpha 137 | try: 138 | A_inv = np.linalg.inv(A) 139 | except np.linalg.linalg.LinAlgError: 140 | A_inv = np.linalg.inv(A + np.random.rand(A.shape[0], A.shape[1]) * 1e-8) 141 | 142 | 143 | m = beta * np.dot(A_inv, self.X_transformed.T) 144 | m = np.dot(m, self.y) 145 | 146 | mll = D / 2 * np.log(alpha) 147 | mll += N / 2 * np.log(beta) 148 | mll -= N / 2 * np.log(2 * np.pi) 149 | mll -= beta / 2. * np.linalg.norm(self.y - np.dot(self.X_transformed, m), 2) 150 | mll -= alpha / 2. * np.dot(m.T, m) 151 | mll -= 0.5 * np.log(np.linalg.det(A)) 152 | 153 | if self.prior is not None: 154 | mll += self.prior.lnprob(theta) 155 | 156 | return mll 157 | 158 | def negative_mll(self, theta): 159 | """ 160 | Returns the negative marginal log likelihood (for optimizing it with scipy). 161 | 162 | Parameters 163 | ---------- 164 | theta: np.array(2,) 165 | The hyperparameter alpha and beta on a log scale 166 | 167 | Returns 168 | ------- 169 | float 170 | negative lnlikelihood + prior 171 | """ 172 | return -self.marginal_log_likelihood(theta) 173 | 174 | @BaseModel._check_shapes_train 175 | def train(self, X, y, do_optimize=True): 176 | """ 177 | First optimized the hyperparameters if do_optimize is True and then computes 178 | the posterior distribution of the weights. 
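        With design matrix Phi (the basis-function features of X), the posterior
        over the weights is the Gaussian N(w | m, S) computed in closed form
        below: S^{-1} = alpha * I + beta * Phi^T Phi and m = beta * S Phi^T y.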
See chapter 3.3 of the book by Bishop 179 | for more details. 180 | 181 | Parameters 182 | ---------- 183 | X: np.ndarray (N, D) 184 | Input data points. The dimensionality of X is (N, D), 185 | with N as the number of points and D is the number of features. 186 | y: np.ndarray (N,) 187 | The corresponding target values. 188 | do_optimize: boolean 189 | If set to true the hyperparameters are optimized otherwise 190 | the default hyperparameters are used. 191 | """ 192 | 193 | self.X = X 194 | 195 | if self.basis_func is not None: 196 | self.X_transformed = self.basis_func(X) 197 | else: 198 | self.X_transformed = self.X 199 | 200 | self.y = y 201 | 202 | if do_optimize: 203 | if self.do_mcmc: 204 | sampler = emcee.EnsembleSampler(self.n_hypers, 2, 205 | self.marginal_log_likelihood) 206 | 207 | # Do a burn-in in the first iteration 208 | if not self.burned: 209 | # Initialize the walkers by sampling from the prior 210 | self.p0 = self.prior.sample_from_prior(self.n_hypers) 211 | 212 | # Run MCMC sampling 213 | result = sampler.run_mcmc(self.p0, 214 | self.burnin_steps, 215 | rstate0=self.rng) 216 | self.p0 = result.coords 217 | 218 | self.burned = True 219 | 220 | # Start sampling 221 | pos = sampler.run_mcmc(self.p0, 222 | self.chain_length, 223 | rstate0=self.rng) 224 | 225 | # Save the current position, it will be the start point in 226 | # the next iteration 227 | self.p0 = pos.coords 228 | 229 | # Take the last samples from each walker 230 | self.hypers = np.exp(sampler.chain[:, -1]) 231 | else: 232 | # Optimize hyperparameters of the Bayesian linear regression 233 | res = optimize.fmin(self.negative_mll, self.rng.rand(2)) 234 | self.hypers = [[np.exp(res[0]), np.exp(res[1])]] 235 | 236 | else: 237 | self.hypers = [[self.alpha, self.beta]] 238 | 239 | self.models = [] 240 | for sample in self.hypers: 241 | alpha = sample[0] 242 | beta = sample[1] 243 | 244 | logger.debug("Alpha=%f ; Beta=%f" % (alpha, beta)) 245 | 246 | S_inv = beta * np.dot(self.X_transformed.T, self.X_transformed) 247 | S_inv += np.eye(self.X_transformed.shape[1]) * alpha 248 | try: 249 | S = np.linalg.inv(S_inv) 250 | except np.linalg.linalg.LinAlgError: 251 | S = np.linalg.inv(S_inv + np.random.rand(S_inv.shape[0], S_inv.shape[1]) * 1e-8) 252 | 253 | m = beta * np.dot(np.dot(S, self.X_transformed.T), self.y) 254 | 255 | self.models.append((m, S)) 256 | 257 | @BaseModel._check_shapes_predict 258 | def predict(self, X_test): 259 | r""" 260 | Returns the predictive mean and variance of the objective function at 261 | the given test points. 262 | 263 | Parameters 264 | ---------- 265 | X_test: np.ndarray (N, D) 266 | N input test points 267 | 268 | Returns 269 | ---------- 270 | np.array(N,) 271 | predictive mean 272 | np.array(N,) 273 | predictive variance 274 | 275 | """ 276 | if self.basis_func is not None: 277 | X_transformed = self.basis_func(X_test) 278 | else: 279 | X_transformed = X_test 280 | 281 | # Marginalise predictions over hyperparameters 282 | mu = np.zeros([len(self.hypers), X_transformed.shape[0]]) 283 | var = np.zeros([len(self.hypers), X_transformed.shape[0]]) 284 | 285 | for i, h in enumerate(self.hypers): 286 | mu[i] = np.dot(self.models[i][0].T, X_transformed.T) 287 | var[i] = 1. 
/ h[1] + np.diag(np.dot(np.dot(X_transformed, self.models[i][1]), X_transformed.T)) 288 | 289 | m = mu.mean(axis=0) 290 | v = var.mean(axis=0) 291 | # Clip negative variances and set them to the smallest 292 | # positive float value 293 | if v.shape[0] == 1: 294 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 295 | else: 296 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 297 | v[np.where((v < np.finfo(v.dtype).eps) & (v > -np.finfo(v.dtype).eps))] = 0 298 | 299 | return m, v 300 | -------------------------------------------------------------------------------- /pybnn/bohamiann.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import typing 4 | from itertools import islice 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.data as data_utils 10 | from scipy.stats import norm 11 | 12 | from pybnn.base_model import BaseModel 13 | from pybnn.priors import weight_prior, log_variance_prior 14 | from pybnn.sampler import AdaptiveSGHMC, SGLD, SGHMC, PreconditionedSGLD 15 | from pybnn.util.infinite_dataloader import infinite_dataloader 16 | from pybnn.util.layers import AppendLayer 17 | from pybnn.util.normalization import zero_mean_unit_var_denormalization, zero_mean_unit_var_normalization 18 | 19 | 20 | def get_default_network(input_dimensionality: int) -> torch.nn.Module: 21 | class Architecture(torch.nn.Module): 22 | def __init__(self, n_inputs, n_hidden=50): 23 | super(Architecture, self).__init__() 24 | self.fc1 = torch.nn.Linear(n_inputs, n_hidden) 25 | self.fc2 = torch.nn.Linear(n_hidden, n_hidden) 26 | self.fc3 = torch.nn.Linear(n_hidden, 1) 27 | self.log_std = AppendLayer(noise=1e-3) 28 | 29 | def forward(self, input): 30 | x = torch.tanh(self.fc1(input)) 31 | x = torch.tanh(self.fc2(x)) 32 | x = self.fc3(x) 33 | return self.log_std(x) 34 | 35 | return Architecture(n_inputs=input_dimensionality) 36 | 37 | 38 | def nll(input: torch.Tensor, target: torch.Tensor): 39 | """ 40 | computes the average negative log-likelihood (Gaussian) 41 | 42 | :param input: mean and variance predictions of the networks 43 | :param target: target values 44 | :return: negative log-likelihood 45 | """ 46 | batch_size = input.size(0) 47 | 48 | prediction_mean = input[:, 0].view((-1, 1)) 49 | log_prediction_variance = input[:, 1].view((-1, 1)) 50 | prediction_variance_inverse = 1. / (torch.exp(log_prediction_variance) + 1e-16) 51 | 52 | mean_squared_error = (target.view(-1, 1) - prediction_mean) ** 2 53 | 54 | log_likelihood = torch.sum( 55 | torch.sum(-mean_squared_error * (0.5 * prediction_variance_inverse) - 0.5 * log_prediction_variance, dim=1)) 56 | 57 | log_likelihood = log_likelihood / batch_size 58 | 59 | return -log_likelihood 60 | 61 | 62 | class Bohamiann(BaseModel): 63 | def __init__(self, 64 | get_network=get_default_network, 65 | normalize_input: bool = True, 66 | normalize_output: bool = True, 67 | sampling_method: str = "adaptive_sghmc", 68 | use_double_precision: bool = True, 69 | metrics=(nn.MSELoss,), 70 | likelihood_function=nll, 71 | print_every_n_steps=100, 72 | ) -> None: 73 | """ 74 | 75 | Bayesian Neural Networks use Bayesian methods to estimate the posterior 76 | distribution of a neural network's weights. This allows to also 77 | predict uncertainties for test points and thus makes Bayesian Neural 78 | Networks suitable for Bayesian optimization. 79 | This module uses stochastic gradient MCMC methods to sample 80 | from the posterior distribution. 
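        A minimal usage sketch (x_train is assumed to be an (N, D) numpy array,
        y_train an (N,) array of targets; the step counts are only illustrative):

            model = Bohamiann()
            model.train(x_train, y_train, num_steps=6000, num_burn_in_steps=2000)
            mean, var = model.predict(x_test)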
81 | 82 | See [1] for more details. 83 | 84 | [1] J. T. Springenberg, A. Klein, S. Falkner, F. Hutter 85 | Bayesian Optimization with Robust Bayesian Neural Networks. 86 | In Advances in Neural Information Processing Systems 29 (2016). 87 | 88 | :param get_network: function handle that returns the archtiecture 89 | :param normalize_input: defines whether to normalize the inputs 90 | :param normalize_output: defines whether to normalize the outputs 91 | :param sampling_method: specifies the sampling strategy, 92 | options: {sgld, sghmc, adaptive_sghmc, preconditioned_sgld} 93 | :param use_double_precision: defines whether to use double or float precisions 94 | :param metrics: metrics to evaluate 95 | :param likelihood_function: function handle that computes the training loss 96 | :param print_every_n_steps: defines after how many the current loss is printed 97 | """ 98 | self.print_every_n_steps = print_every_n_steps 99 | self.metrics = metrics 100 | self.do_normalize_input = normalize_input 101 | self.do_normalize_output = normalize_output 102 | self.get_network = get_network 103 | self.is_trained = False 104 | self.use_double_precision = use_double_precision 105 | self.sampling_method = sampling_method 106 | self.sampled_weights = [] # type: typing.List[typing.Tuple[np.ndarray]] 107 | self.likelihood_function = likelihood_function 108 | self.sampler = None 109 | 110 | @property 111 | def network_weights(self) -> tuple: 112 | """ 113 | Extract current network weight values as `np.ndarray`. 114 | 115 | :return: Tuple containing current network weight values 116 | """ 117 | return tuple( 118 | np.asarray(parameter.data.clone().detach().numpy()) 119 | for parameter in self.model.parameters() 120 | ) 121 | 122 | @network_weights.setter 123 | def network_weights(self, weights: typing.List[np.ndarray]) -> None: 124 | """ 125 | Assign new `weights` to our neural networks parameters. 126 | 127 | :param weights: List of weight values to assign. 128 | Individual list elements must have shapes that match 129 | the network parameters with the same index in `self.network_weights`. 130 | """ 131 | logging.debug("Assigning new network weights") 132 | for parameter, sample in zip(self.model.parameters(), weights): 133 | parameter.copy_(torch.from_numpy(sample)) 134 | 135 | def train(self, x_train: np.ndarray, y_train: np.ndarray, 136 | num_steps: int = 13000, 137 | keep_every: int = 100, 138 | num_burn_in_steps: int = 3000, 139 | lr: float = 1e-2, 140 | batch_size=20, 141 | epsilon: float = 1e-10, 142 | mdecay: float = 0.05, 143 | continue_training: bool = False, 144 | verbose: bool = False, 145 | **kwargs): 146 | 147 | """ 148 | Train a BNN using input datapoints `x_train` with corresponding targets `y_train`. 149 | 150 | :param x_train: input training datapoints. 151 | :param y_train: input training targets. 152 | :param num_steps: Number of sampling steps to perform after burn-in is finished. 153 | In total, `num_steps // keep_every` network weights will be sampled. 154 | :param keep_every: Number of sampling steps (after burn-in) to perform before keeping a sample. 155 | In total, `num_steps // keep_every` network weights will be sampled. 156 | :param num_burn_in_steps: Number of burn-in steps to perform. 157 | This value is passed to the given `optimizer` if it supports special 158 | burn-in specific behavior. 159 | Networks sampled during burn-in are discarded. 
160 | :param lr: learning rate 161 | :param batch_size: batch size 162 | :param epsilon: epsilon for numerical stability 163 | :param mdecay: momemtum decay 164 | :param continue_training: defines whether we want to continue from the last training run 165 | :param verbose: verbose output 166 | """ 167 | logging.debug("Training started.") 168 | start_time = time.time() 169 | 170 | num_datapoints, input_dimensionality = x_train.shape 171 | logging.debug( 172 | "Processing %d training datapoints " 173 | " with % dimensions each." % (num_datapoints, input_dimensionality) 174 | ) 175 | assert batch_size >= 1, "Invalid batch size. Batches must contain at least a single sample." 176 | assert len(y_train.shape) == 1 or (len(y_train.shape) == 2 and y_train.shape[ 177 | 1] == 1), "Targets need to be in vector format, i.e (N,) or (N,1)" 178 | 179 | if x_train.shape[0] < batch_size: 180 | logging.warning("Not enough datapoints to form a batch. Use all datapoints in each batch") 181 | batch_size = x_train.shape[0] 182 | 183 | self.X = x_train 184 | if len(y_train.shape) == 2: 185 | self.y = y_train[:, 0] 186 | else: 187 | self.y = y_train 188 | 189 | if self.do_normalize_input: 190 | logging.debug( 191 | "Normalizing training datapoints to " 192 | " zero mean and unit variance." 193 | ) 194 | x_train_, self.x_mean, self.x_std = self.normalize_input(x_train) 195 | if self.use_double_precision: 196 | x_train_ = torch.from_numpy(x_train_).double() 197 | else: 198 | x_train_ = torch.from_numpy(x_train_).float() 199 | else: 200 | if self.use_double_precision: 201 | x_train_ = torch.from_numpy(x_train).double() 202 | else: 203 | x_train_ = torch.from_numpy(x_train).float() 204 | 205 | if self.do_normalize_output: 206 | logging.debug("Normalizing training labels to zero mean and unit variance.") 207 | y_train_, self.y_mean, self.y_std = self.normalize_output(self.y) 208 | 209 | if self.use_double_precision: 210 | y_train_ = torch.from_numpy(y_train_).double() 211 | else: 212 | y_train_ = torch.from_numpy(y_train_).float() 213 | else: 214 | if self.use_double_precision: 215 | y_train_ = torch.from_numpy(y_train).double() 216 | else: 217 | y_train_ = torch.from_numpy(y_train).float() 218 | 219 | train_loader = infinite_dataloader( 220 | data_utils.DataLoader( 221 | data_utils.TensorDataset(x_train_, y_train_), 222 | batch_size=batch_size, 223 | shuffle=True 224 | ) 225 | ) 226 | 227 | if self.use_double_precision: 228 | dtype = np.float64 229 | else: 230 | dtype = np.float32 231 | 232 | if not continue_training: 233 | logging.debug("Clearing list of sampled weights.") 234 | 235 | self.sampled_weights.clear() 236 | if self.use_double_precision: 237 | self.model = self.get_network(input_dimensionality=input_dimensionality).double() 238 | else: 239 | self.model = self.get_network(input_dimensionality=input_dimensionality).float() 240 | 241 | if self.sampling_method == "adaptive_sghmc": 242 | self.sampler = AdaptiveSGHMC(self.model.parameters(), 243 | scale_grad=dtype(num_datapoints), 244 | num_burn_in_steps=num_burn_in_steps, 245 | lr=dtype(lr), 246 | mdecay=dtype(mdecay), 247 | epsilon=dtype(epsilon)) 248 | elif self.sampling_method == "sgld": 249 | self.sampler = SGLD(self.model.parameters(), 250 | lr=dtype(lr), 251 | scale_grad=num_datapoints) 252 | elif self.sampling_method == "preconditioned_sgld": 253 | self.sampler = PreconditionedSGLD(self.model.parameters(), 254 | lr=dtype(lr), 255 | num_train_points=num_datapoints) 256 | elif self.sampling_method == "sghmc": 257 | self.sampler = 
SGHMC(self.model.parameters(), 258 | scale_grad=dtype(num_datapoints), 259 | mdecay=dtype(mdecay), 260 | lr=dtype(lr)) 261 | 262 | batch_generator = islice(enumerate(train_loader), num_steps) 263 | 264 | for step, (x_batch, y_batch) in batch_generator: 265 | self.sampler.zero_grad() 266 | loss = self.likelihood_function(input=self.model(x_batch), target=y_batch) 267 | # Add prior. Note the gradient is computed by: g_prior + N/n sum_i grad_theta_xi see Eq 4 268 | # in Welling and Whye The 2011. Because of that we divide here by N=num of datapoints since 269 | # in the sample we rescale the gradient by N again 270 | loss -= log_variance_prior(self.model(x_batch)[:, 1].view((-1, 1))) / num_datapoints 271 | loss -= weight_prior(self.model.parameters(), dtype=dtype) / num_datapoints 272 | loss.backward() 273 | self.sampler.step() 274 | 275 | if verbose and step > 0 and step % self.print_every_n_steps == 0: 276 | 277 | # compute the training performance of the ensemble 278 | if len(self.sampled_weights) > 1: 279 | mu, var = self.predict(x_train) 280 | total_nll = -np.mean(norm.logpdf(y_train, loc=mu, scale=np.sqrt(var))) 281 | total_mse = np.mean((y_train - mu) ** 2) 282 | # in case we do not have an ensemble we compute the performance of the last weight sample 283 | else: 284 | f = self.model(x_train_) 285 | 286 | if self.do_normalize_output: 287 | mu = zero_mean_unit_var_denormalization(f[:, 0], self.y_mean, self.y_std).data.numpy() 288 | var = torch.exp(f[:, 1]) * self.y_std ** 2 289 | var = var.data.numpy() 290 | else: 291 | mu = f[:, 0].data.numpy() 292 | var = np.exp(f[:, 1].data.numpy()) 293 | total_nll = -np.mean(norm.logpdf(y_train, loc=mu, scale=np.sqrt(var))) 294 | total_mse = np.mean((y_train - mu) ** 2) 295 | 296 | t = time.time() - start_time 297 | 298 | if step < num_burn_in_steps: 299 | print("Step {:8d} : NLL = {:11.4e} MSE = {:.4e} " 300 | "Time = {:5.2f}".format(step, float(total_nll), 301 | float(total_mse), t)) 302 | 303 | if step > num_burn_in_steps: 304 | print("Step {:8d} : NLL = {:11.4e} MSE = {:.4e} " 305 | "Samples= {} Time = {:5.2f}".format(step, 306 | float(total_nll), 307 | float(total_mse), 308 | len(self.sampled_weights), t)) 309 | 310 | if step > num_burn_in_steps and (step - num_burn_in_steps) % keep_every == 0: 311 | weights = self.network_weights 312 | 313 | self.sampled_weights.append(weights) 314 | 315 | self.is_trained = True 316 | 317 | def train_and_evaluate(self, x_train: np.ndarray, y_train: np.ndarray, 318 | x_valid: np.ndarray, y_valid: np.ndarray, 319 | num_steps: int = 13000, 320 | validate_every_n_steps=1000, 321 | keep_every: int = 100, 322 | num_burn_in_steps: int = 3000, 323 | lr: float = 1e-2, 324 | epsilon: float = 1e-10, 325 | batch_size: int = 20, 326 | mdecay: float = 0.05, 327 | verbose=False): 328 | """ 329 | Train and validates the neural network 330 | 331 | :param x_train: input training datapoints. 332 | :param y_train: input training targets. 333 | :param x_valid: validation data points 334 | :param y_valid: valdiation targets 335 | :param num_steps: Number of sampling steps to perform after burn-in is finished. 336 | In total, `num_steps // keep_every` network weights will be sampled. 337 | :param validate_every_n_steps: 338 | :param keep_every: Number of sampling steps (after burn-in) to perform before keeping a sample. 339 | In total, `num_steps // keep_every` network weights will be sampled. 340 | :param num_burn_in_steps: Number of burn-in steps to perform. 
341 | This value is passed to the given `optimizer` if it supports special 342 | burn-in specific behavior. 343 | Networks sampled during burn-in are discarded. 344 | :param lr: learning rate 345 | :param batch_size: batch size 346 | :param epsilon: epsilon for numerical stability 347 | :param mdecay: momemtum decay 348 | :param verbose: verbose output 349 | 350 | """ 351 | assert batch_size >= 1, "Invalid batch size. Batches must contain at least a single sample." 352 | 353 | if x_train.shape[0] < batch_size: 354 | logging.warning("Not enough datapoints to form a batch. Use all datapoints in each batch") 355 | batch_size = x_train.shape[0] 356 | 357 | # burn-in 358 | self.train(x_train, y_train, num_burn_in_steps=num_burn_in_steps, num_steps=num_burn_in_steps, 359 | lr=lr, epsilon=epsilon, mdecay=mdecay, verbose=verbose) 360 | 361 | learning_curve_mse = [] 362 | learning_curve_ll = [] 363 | n_steps = [] 364 | for i in range(num_steps // validate_every_n_steps): 365 | self.train(x_train, y_train, num_burn_in_steps=0, num_steps=validate_every_n_steps, 366 | lr=lr, epsilon=epsilon, mdecay=mdecay, verbose=verbose, keep_every=keep_every, 367 | continue_training=True, batch_size=batch_size) 368 | 369 | mu, var = self.predict(x_valid) 370 | 371 | ll = np.mean(norm.logpdf(y_valid, loc=mu, scale=np.sqrt(var))) 372 | mse = np.mean((y_valid - mu) ** 2) 373 | step = num_burn_in_steps + (i + 1) * validate_every_n_steps 374 | 375 | learning_curve_ll.append(ll) 376 | learning_curve_mse.append(mse) 377 | n_steps.append(step) 378 | 379 | if verbose: 380 | print("Validate : NLL = {:11.4e} MSE = {:.4e}".format(-ll, mse)) 381 | 382 | return n_steps, learning_curve_ll, learning_curve_mse 383 | 384 | def normalize_input(self, x, m=None, s=None): 385 | """ 386 | Normalizes input 387 | 388 | :param x: data 389 | :param m: mean 390 | :param s: standard deviation 391 | :return: normalized input 392 | """ 393 | 394 | return zero_mean_unit_var_normalization(x, m, s) 395 | 396 | def normalize_output(self, x, m=None, s=None): 397 | """ 398 | Normalizes output 399 | 400 | :param x: targets 401 | :param m: mean 402 | :param s: standard deviation 403 | :return: normalized targets 404 | """ 405 | return zero_mean_unit_var_normalization(x, m, s) 406 | 407 | def predict(self, x_test: np.ndarray, return_individual_predictions: bool = False): 408 | """ 409 | Predicts mean and variance for the given test point 410 | 411 | :param x_test: test datapoint 412 | :param return_individual_predictions: if True also the predictions of the individual models are returned 413 | :return: mean and variance 414 | """ 415 | x_test_ = np.asarray(x_test) 416 | 417 | if self.do_normalize_input: 418 | x_test_, *_ = self.normalize_input(x_test_, self.x_mean, self.x_std) 419 | 420 | def network_predict(x_test_, weights): 421 | with torch.no_grad(): 422 | self.network_weights = weights 423 | if self.use_double_precision: 424 | return self.model(torch.from_numpy(x_test_).double()).numpy() 425 | else: 426 | return self.model(torch.from_numpy(x_test_).float()).numpy() 427 | 428 | logging.debug("Predicting with %d networks." 
% len(self.sampled_weights)) 429 | network_outputs = np.array([ 430 | network_predict(x_test_, weights=weights) 431 | for weights in self.sampled_weights 432 | ]) 433 | 434 | mean_prediction = np.mean(network_outputs[:, :, 0], axis=0) 435 | # variance_prediction = np.mean((network_outputs[:, :, 0] - mean_prediction) ** 2, axis=0) 436 | # Total variance 437 | variance_prediction = np.mean((network_outputs[:, :, 0] - mean_prediction) ** 2 438 | + np.exp(network_outputs[:, :, 1]), axis=0) 439 | 440 | if self.do_normalize_output: 441 | 442 | mean_prediction = zero_mean_unit_var_denormalization( 443 | mean_prediction, self.y_mean, self.y_std 444 | ) 445 | variance_prediction *= self.y_std ** 2 446 | 447 | for i in range(len(network_outputs)): 448 | network_outputs[i] = zero_mean_unit_var_denormalization( 449 | network_outputs[i], self.y_mean, self.y_std 450 | ) 451 | 452 | if return_individual_predictions: 453 | return mean_prediction, variance_prediction, network_outputs[:, :, 0] 454 | 455 | return mean_prediction, variance_prediction 456 | 457 | def predict_single(self, x_test: np.ndarray, sample_index: int): 458 | """ 459 | Compute the prediction of a single weight sample 460 | 461 | :param x_test: test datapoint 462 | :param sample_index: specifies the index of the weight sample 463 | :return: mean and variance of the neural network 464 | """ 465 | x_test_ = np.asarray(x_test) 466 | 467 | if self.do_normalize_input: 468 | x_test_, *_ = self.normalize_input(x_test_, self.x_mean, self.x_std) 469 | 470 | def network_predict(x_test_, weights): 471 | with torch.no_grad(): 472 | self.network_weights = weights 473 | if self.use_double_precision: 474 | return self.model(torch.from_numpy(x_test_).double()).numpy() 475 | else: 476 | return self.model(torch.from_numpy(x_test_).float()).numpy() 477 | 478 | logging.debug("Predicting with %d networks." 
% len(self.sampled_weights)) 479 | function_value = np.array(network_predict(x_test_, weights=self.sampled_weights[sample_index])) 480 | 481 | if self.do_normalize_output: 482 | function_value = zero_mean_unit_var_denormalization( 483 | function_value, self.y_mean, self.y_std 484 | ) 485 | return function_value 486 | 487 | def f_gradient(self, x_test, weights): 488 | x_test_ = np.asarray(x_test) 489 | 490 | with torch.no_grad(): 491 | self.network_weights = weights 492 | 493 | if self.use_double_precision: 494 | x = torch.autograd.Variable(torch.from_numpy(x_test_[None, :]).double(), requires_grad=True) 495 | else: 496 | x = torch.autograd.Variable(torch.from_numpy(x_test_[None, :]).float(), requires_grad=True) 497 | 498 | if self.do_normalize_input: 499 | if self.use_double_precision: 500 | x_mean = torch.autograd.Variable(torch.from_numpy(self.x_mean).double(), requires_grad=False) 501 | x_std = torch.autograd.Variable(torch.from_numpy(self.x_std).double(), requires_grad=False) 502 | else: 503 | x_mean = torch.autograd.Variable(torch.from_numpy(self.x_mean).float(), requires_grad=False) 504 | x_std = torch.autograd.Variable(torch.from_numpy(self.x_std).float(), requires_grad=False) 505 | 506 | x_norm = (x - x_mean) / x_std 507 | m = self.model(x_norm)[0][0] 508 | else: 509 | m = self.model(x)[0][0] 510 | if self.do_normalize_output: 511 | 512 | if self.use_double_precision: 513 | y_mean = torch.autograd.Variable(torch.from_numpy(np.array([self.y_mean])).double(), 514 | requires_grad=False) 515 | y_std = torch.autograd.Variable(torch.from_numpy(np.array([self.y_std])).double(), requires_grad=False) 516 | 517 | else: 518 | y_mean = torch.autograd.Variable(torch.from_numpy(np.array([self.y_mean])).float(), requires_grad=False) 519 | y_std = torch.autograd.Variable(torch.from_numpy(np.array([self.y_std])).float(), requires_grad=False) 520 | 521 | m = m * y_std + y_mean 522 | 523 | m.backward() 524 | 525 | g = x.grad.data.numpy()[0, :] 526 | return g 527 | 528 | def predictive_mean_gradient(self, x_test: np.ndarray): 529 | 530 | # compute the individual gradients for each weight vector 531 | grads = np.array([self.f_gradient(x_test, weights=weights) for weights in self.sampled_weights]) 532 | 533 | # the gradient of the mean is mean of all individual gradients 534 | g = np.mean(grads, axis=0) 535 | 536 | return g 537 | 538 | def predictive_variance_gradient(self, x_test: np.ndarray): 539 | m, v, funcs = self.predict(x_test[None, :], return_individual_predictions=True) 540 | 541 | grads = np.array([self.f_gradient(x_test, weights=weights) for weights in self.sampled_weights]) 542 | 543 | dmdx = self.predictive_mean_gradient(x_test) 544 | 545 | g = np.mean([2 * (funcs[i] - m) * (grads[i] - dmdx) for i in range(len(self.sampled_weights))], axis=0) 546 | 547 | return g 548 | -------------------------------------------------------------------------------- /pybnn/dngo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import numpy as np 4 | import emcee 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 10 | from scipy import optimize 11 | 12 | from pybnn.base_model import BaseModel 13 | from pybnn.util.normalization import zero_mean_unit_var_normalization, zero_mean_unit_var_denormalization 14 | from pybnn.bayesian_linear_regression import BayesianLinearRegression, Prior 15 | 16 | 17 | class Net(nn.Module): 18 | def __init__(self, n_inputs, n_units=[50, 50, 50]): 19 | super(Net, 
self).__init__() 20 | self.fc1 = nn.Linear(n_inputs, n_units[0]) 21 | self.fc2 = nn.Linear(n_units[0], n_units[1]) 22 | self.fc3 = nn.Linear(n_units[1], n_units[2]) 23 | self.out = nn.Linear(n_units[2], 1) 24 | 25 | def forward(self, x): 26 | x = torch.tanh(self.fc1(x)) 27 | x = torch.tanh(self.fc2(x)) 28 | x = torch.tanh(self.fc3(x)) 29 | 30 | return self.out(x) 31 | 32 | def basis_funcs(self, x): 33 | x = torch.tanh(self.fc1(x)) 34 | x = torch.tanh(self.fc2(x)) 35 | x = torch.tanh(self.fc3(x)) 36 | return x 37 | 38 | 39 | class DNGO(BaseModel): 40 | 41 | def __init__(self, batch_size=10, num_epochs=500, 42 | learning_rate=0.01, 43 | adapt_epoch=5000, n_units_1=50, n_units_2=50, n_units_3=50, 44 | alpha=1.0, beta=1000, prior=None, do_mcmc=True, 45 | n_hypers=20, chain_length=2000, burnin_steps=2000, 46 | normalize_input=True, normalize_output=True, rng=None): 47 | """ 48 | Deep Networks for Global Optimization [1]. This module performs 49 | Bayesian Linear Regression with basis function extracted from a 50 | feed forward neural network. 51 | 52 | [1] J. Snoek, O. Rippel, K. Swersky, R. Kiros, N. Satish, 53 | N. Sundaram, M.~M.~A. Patwary, Prabhat, R.~P. Adams 54 | Scalable Bayesian Optimization Using Deep Neural Networks 55 | Proc. of ICML'15 56 | 57 | Parameters 58 | ---------- 59 | batch_size: int 60 | Batch size for training the neural network 61 | num_epochs: int 62 | Number of epochs for training 63 | learning_rate: float 64 | Initial learning rate for Adam 65 | adapt_epoch: int 66 | Defines after how many epochs the learning rate will be decayed by a factor 10 67 | n_units_1: int 68 | Number of units in layer 1 69 | n_units_2: int 70 | Number of units in layer 2 71 | n_units_3: int 72 | Number of units in layer 3 73 | alpha: float 74 | Hyperparameter of the Bayesian linear regression 75 | beta: float 76 | Hyperparameter of the Bayesian linear regression 77 | prior: Prior object 78 | Prior for alpa and beta. 
If set to None the default prior is used 79 | do_mcmc: bool 80 | If set to true different values for alpha and beta are sampled via MCMC from the marginal log likelihood 81 | Otherwise the marginal log likehood is optimized with scipy fmin function 82 | n_hypers : int 83 | Number of samples for alpha and beta 84 | chain_length : int 85 | The chain length of the MCMC sampler 86 | burnin_steps: int 87 | The number of burnin steps before the sampling procedure starts 88 | normalize_output : bool 89 | Zero mean unit variance normalization of the output values 90 | normalize_input : bool 91 | Zero mean unit variance normalization of the input values 92 | rng: np.random.RandomState 93 | Random number generator 94 | """ 95 | 96 | if rng is None: 97 | self.rng = np.random.RandomState(np.random.randint(0, 10000)) 98 | else: 99 | self.rng = rng 100 | 101 | self.X = None 102 | self.y = None 103 | self.network = None 104 | self.alpha = alpha 105 | self.beta = beta 106 | self.normalize_input = normalize_input 107 | self.normalize_output = normalize_output 108 | 109 | # MCMC hyperparameters 110 | self.do_mcmc = do_mcmc 111 | self.n_hypers = n_hypers 112 | self.chain_length = chain_length 113 | self.burned = False 114 | self.burnin_steps = burnin_steps 115 | if prior is None: 116 | self.prior = Prior(rng=self.rng) 117 | else: 118 | self.prior = prior 119 | 120 | # Network hyper parameters 121 | self.num_epochs = num_epochs 122 | self.batch_size = batch_size 123 | self.init_learning_rate = learning_rate 124 | 125 | self.n_units_1 = n_units_1 126 | self.n_units_2 = n_units_2 127 | self.n_units_3 = n_units_3 128 | self.adapt_epoch = adapt_epoch 129 | self.network = None 130 | self.models = [] 131 | self.hypers = None 132 | 133 | @BaseModel._check_shapes_train 134 | def train(self, X, y, do_optimize=True): 135 | """ 136 | Trains the model on the provided data. 137 | 138 | Parameters 139 | ---------- 140 | X: np.ndarray (N, D) 141 | Input data points. The dimensionality of X is (N, D), 142 | with N as the number of points and D is the number of features. 143 | y: np.ndarray (N,) 144 | The corresponding target values. 145 | do_optimize: boolean 146 | If set to true the hyperparameters are optimized otherwise 147 | the default hyperparameters are used. 
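        Example (a minimal sketch; x_train is assumed to be an (N, D) numpy
        array, y_train an (N,) array and x_test an (M, D) array):

            model = DNGO(do_mcmc=False)
            model.train(x_train, y_train, do_optimize=True)
            mean, var = model.predict(x_test)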
148 | 149 | """ 150 | start_time = time.time() 151 | 152 | # Normalize inputs 153 | if self.normalize_input: 154 | self.X, self.X_mean, self.X_std = zero_mean_unit_var_normalization(X) 155 | else: 156 | self.X = X 157 | 158 | # Normalize ouputs 159 | if self.normalize_output: 160 | self.y, self.y_mean, self.y_std = zero_mean_unit_var_normalization(y) 161 | else: 162 | self.y = y 163 | 164 | self.y = self.y[:, None] 165 | 166 | # Check if we have enough points to create a minibatch otherwise use all data points 167 | if self.X.shape[0] <= self.batch_size: 168 | batch_size = self.X.shape[0] 169 | else: 170 | batch_size = self.batch_size 171 | 172 | # Create the neural network 173 | features = X.shape[1] 174 | 175 | self.network = Net(n_inputs=features, n_units=[self.n_units_1, self.n_units_2, self.n_units_3]) 176 | 177 | optimizer = optim.Adam(self.network.parameters(), 178 | lr=self.init_learning_rate) 179 | 180 | # Start training 181 | lc = np.zeros([self.num_epochs]) 182 | for epoch in range(self.num_epochs): 183 | 184 | epoch_start_time = time.time() 185 | 186 | train_err = 0 187 | train_batches = 0 188 | 189 | for batch in self.iterate_minibatches(self.X, self.y, 190 | batch_size, shuffle=True): 191 | inputs = torch.Tensor(batch[0]) 192 | targets = torch.Tensor(batch[1]) 193 | 194 | optimizer.zero_grad() 195 | output = self.network(inputs) 196 | loss = torch.nn.functional.mse_loss(output, targets) 197 | loss.backward() 198 | optimizer.step() 199 | 200 | train_err += loss 201 | train_batches += 1 202 | 203 | lc[epoch] = train_err / train_batches 204 | logging.debug("Epoch {} of {}".format(epoch + 1, self.num_epochs)) 205 | curtime = time.time() 206 | epoch_time = curtime - epoch_start_time 207 | total_time = curtime - start_time 208 | logging.debug("Epoch time {:.3f}s, total time {:.3f}s".format(epoch_time, total_time)) 209 | logging.debug("Training loss:\t\t{:.5g}".format(train_err / train_batches)) 210 | 211 | # Design matrix 212 | self.Theta = self.network.basis_funcs(torch.Tensor(self.X)).data.numpy() 213 | 214 | if do_optimize: 215 | if self.do_mcmc: 216 | self.sampler = emcee.EnsembleSampler(self.n_hypers, 2, 217 | self.marginal_log_likelihood) 218 | 219 | # Do a burn-in in the first iteration 220 | if not self.burned: 221 | # Initialize the walkers by sampling from the prior 222 | self.p0 = self.prior.sample_from_prior(self.n_hypers) 223 | # Run MCMC sampling 224 | result = self.sampler.run_mcmc(self.p0, 225 | self.burnin_steps, 226 | rstate0=self.rng) 227 | self.p0 = result.coords 228 | 229 | self.burned = True 230 | 231 | # Start sampling 232 | pos = self.sampler.run_mcmc(self.p0, 233 | self.chain_length, 234 | rstate0=self.rng) 235 | 236 | # Save the current position, it will be the startpoint in 237 | # the next iteration 238 | self.p0 = pos.coords 239 | 240 | # Take the last samples from each walker set them back on a linear scale 241 | linear_theta = np.exp(self.sampler.chain[:, -1]) 242 | self.hypers = linear_theta 243 | self.hypers[:, 1] = 1 / self.hypers[:, 1] 244 | else: 245 | # Optimize hyperparameters of the Bayesian linear regression 246 | p0 = self.prior.sample_from_prior(n_samples=1) 247 | res = optimize.fmin(self.negative_mll, p0) 248 | self.hypers = [[np.exp(res[0]), 1 / np.exp(res[1])]] 249 | else: 250 | 251 | self.hypers = [[self.alpha, self.beta]] 252 | 253 | logging.info("Hypers: %s" % self.hypers) 254 | self.models = [] 255 | for sample in self.hypers: 256 | # Instantiate a model for each hyperparameter configuration 257 | model = 
BayesianLinearRegression(alpha=sample[0], 258 | beta=sample[1], 259 | basis_func=None) 260 | model.train(self.Theta, self.y[:, 0], do_optimize=False) 261 | 262 | self.models.append(model) 263 | 264 | def marginal_log_likelihood(self, theta): 265 | """ 266 | Log likelihood of the data marginalised over the weights w. See chapter 3.5 of 267 | the book by Bishop of an derivation. 268 | 269 | Parameters 270 | ---------- 271 | theta: np.array(2,) 272 | The hyperparameter alpha and beta on a log scale 273 | 274 | Returns 275 | ------- 276 | float 277 | lnlikelihood + prior 278 | """ 279 | if np.any(theta == np.inf): 280 | return -np.inf 281 | 282 | if np.any((-10 > theta) + (theta > 10)): 283 | return -np.inf 284 | 285 | alpha = np.exp(theta[0]) 286 | beta = 1 / np.exp(theta[1]) 287 | 288 | D = self.Theta.shape[1] 289 | N = self.Theta.shape[0] 290 | 291 | K = beta * np.dot(self.Theta.T, self.Theta) 292 | K += np.eye(self.Theta.shape[1]) * alpha 293 | try: 294 | K_inv = np.linalg.inv(K) 295 | except np.linalg.linalg.LinAlgError: 296 | K_inv = np.linalg.inv(K + np.random.rand(K.shape[0], K.shape[1]) * 1e-8) 297 | 298 | m = beta * np.dot(K_inv, self.Theta.T) 299 | m = np.dot(m, self.y) 300 | 301 | mll = D / 2 * np.log(alpha) 302 | mll += N / 2 * np.log(beta) 303 | mll -= N / 2 * np.log(2 * np.pi) 304 | mll -= beta / 2. * np.linalg.norm(self.y - np.dot(self.Theta, m), 2) 305 | mll -= alpha / 2. * np.dot(m.T, m) 306 | mll -= 0.5 * np.log(np.linalg.det(K) + 1e-10) 307 | 308 | if np.any(np.isnan(mll)): 309 | return -1e25 310 | return mll 311 | 312 | def negative_mll(self, theta): 313 | """ 314 | Returns the negative marginal log likelihood (for optimizing it with scipy). 315 | 316 | Parameters 317 | ---------- 318 | theta: np.array(2,) 319 | The hyperparameter alpha and beta on a log scale 320 | 321 | Returns 322 | ------- 323 | float 324 | negative lnlikelihood + prior 325 | """ 326 | nll = -self.marginal_log_likelihood(theta) 327 | return nll 328 | 329 | def iterate_minibatches(self, inputs, targets, batchsize, shuffle=False): 330 | assert inputs.shape[0] == targets.shape[0], \ 331 | "The number of training points is not the same" 332 | if shuffle: 333 | indices = np.arange(inputs.shape[0]) 334 | self.rng.shuffle(indices) 335 | for start_idx in range(0, inputs.shape[0] - batchsize + 1, batchsize): 336 | if shuffle: 337 | excerpt = indices[start_idx:start_idx + batchsize] 338 | else: 339 | excerpt = slice(start_idx, start_idx + batchsize) 340 | yield inputs[excerpt], targets[excerpt] 341 | 342 | @BaseModel._check_shapes_predict 343 | def predict(self, X_test): 344 | r""" 345 | Returns the predictive mean and variance of the objective function at 346 | the given test points. 
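        Predictions are marginalised over the sampled (alpha, beta)
        hyperparameters: with per-sample means mu_i and variances var_i from the
        underlying Bayesian linear regression models, the code below returns
        m = mean_i(mu_i) and v = mean_i(mu_i^2 + var_i) - m^2 (law of total
        variance).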
347 | 348 | Parameters 349 | ---------- 350 | X_test: np.ndarray (N, D) 351 | N input test points 352 | 353 | Returns 354 | ---------- 355 | np.array(N,) 356 | predictive mean 357 | np.array(N,) 358 | predictive variance 359 | 360 | """ 361 | # Normalize inputs 362 | if self.normalize_input: 363 | X_, _, _ = zero_mean_unit_var_normalization(X_test, self.X_mean, self.X_std) 364 | else: 365 | X_ = X_test 366 | 367 | # Get features from the net 368 | 369 | theta = self.network.basis_funcs(torch.Tensor(X_)).data.numpy() 370 | 371 | # Marginalise predictions over hyperparameters of the BLR 372 | mu = np.zeros([len(self.models), X_test.shape[0]]) 373 | var = np.zeros([len(self.models), X_test.shape[0]]) 374 | 375 | for i, m in enumerate(self.models): 376 | mu[i], var[i] = m.predict(theta) 377 | 378 | # See the algorithm runtime prediction paper by Hutter et al 379 | # for the derivation of the total variance 380 | m = np.mean(mu, axis=0) 381 | v = np.mean(mu ** 2 + var, axis=0) - m ** 2 382 | 383 | # Clip negative variances and set them to the smallest 384 | # positive float value 385 | if v.shape[0] == 1: 386 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 387 | else: 388 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 389 | v[np.where((v < np.finfo(v.dtype).eps) & (v > -np.finfo(v.dtype).eps))] = 0 390 | 391 | if self.normalize_output: 392 | m = zero_mean_unit_var_denormalization(m, self.y_mean, self.y_std) 393 | v *= self.y_std ** 2 394 | 395 | return m, v 396 | 397 | def get_incumbent(self): 398 | """ 399 | Returns the best observed point and its function value 400 | 401 | Returns 402 | ---------- 403 | incumbent: ndarray (D,) 404 | current incumbent 405 | incumbent_value: ndarray (N,) 406 | the observed value of the incumbent 407 | """ 408 | 409 | inc, inc_value = super(DNGO, self).get_incumbent() 410 | if self.normalize_input: 411 | inc = zero_mean_unit_var_denormalization(inc, self.X_mean, self.X_std) 412 | 413 | if self.normalize_output: 414 | inc_value = zero_mean_unit_var_denormalization(inc_value, self.y_mean, self.y_std) 415 | 416 | return inc, inc_value 417 | -------------------------------------------------------------------------------- /pybnn/lc_extrapolation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/pybnn/59befe512d6f668f1dfb39ad9ba4c55abc8dd0f6/pybnn/lc_extrapolation/__init__.py -------------------------------------------------------------------------------- /pybnn/lc_extrapolation/curvefunctions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # all the models that we considered at some point 4 | all_models = {} 5 | model_defaults = {} 6 | display_name_mapping = {} 7 | 8 | 9 | def pow3(x, c, a, alpha): 10 | return c - a * x ** (-alpha) 11 | 12 | 13 | all_models["pow3"] = pow3 14 | model_defaults["pow3"] = {"c": 0.84, "a": 0.52, "alpha": 0.01} 15 | display_name_mapping["pow3"] = "pow$_3$" 16 | 17 | 18 | def linear(x, a, b): 19 | return a * x + b 20 | 21 | 22 | # models["linear"] = linear 23 | all_models["linear"] = linear 24 | 25 | """ 26 | Source: curve expert 27 | """ 28 | 29 | 30 | def log_power(x, a, b, c): 31 | # logistic power 32 | return a / (1. 
+ (x / np.exp(b)) ** c) 33 | 34 | 35 | all_models["log_power"] = log_power 36 | model_defaults["log_power"] = {"a": 0.77, "c": -0.51, "b": 2.98} 37 | display_name_mapping["log_power"] = "log power" 38 | 39 | 40 | def weibull(x, alpha, beta, kappa, delta): 41 | """ 42 | Weibull modell 43 | 44 | http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm 45 | 46 | alpha: upper asymptote 47 | beta: lower asymptote 48 | k: growth rate 49 | delta: controls the x-ordinate for the point of inflection 50 | """ 51 | return alpha - (alpha - beta) * np.exp(-(kappa * x) ** delta) 52 | 53 | 54 | all_models["weibull"] = weibull 55 | model_defaults["weibull"] = {"alpha": .7, "beta": 0.1, "kappa": 0.01, 56 | "delta": 1} 57 | display_name_mapping["weibull"] = "Weibull" 58 | 59 | 60 | def mmf(x, alpha, beta, kappa, delta): 61 | """ 62 | Morgan-Mercer-Flodin 63 | 64 | description: 65 | Nonlinear Regression page 342 66 | http://bit.ly/1jodG17 67 | http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm 68 | 69 | alpha: upper asymptote 70 | kappa: growth rate 71 | beta: initial value 72 | delta: controls the point of inflection 73 | """ 74 | return alpha - (alpha - beta) / (1. + (kappa * x) ** delta) 75 | 76 | 77 | all_models["mmf"] = mmf 78 | model_defaults["mmf"] = {"alpha": .7, "kappa": 0.01, "beta": 0.1, "delta": 5} 79 | display_name_mapping["mmf"] = "MMF" 80 | 81 | 82 | def janoschek(x, a, beta, k, delta): 83 | """ 84 | http://www.pisces-conservation.com/growthhelp/janoschek.htm 85 | """ 86 | return a - (a - beta) * np.exp(-k * x ** delta) 87 | 88 | 89 | all_models["janoschek"] = janoschek 90 | model_defaults["janoschek"] = {"a": 0.73, "beta": 0.07, "k": 0.355, 91 | "delta": 0.46} 92 | display_name_mapping["janoschek"] = "Janoschek" 93 | 94 | 95 | def ilog2(x, c, a): 96 | x = 1 + x 97 | assert (np.all(x > 1)) 98 | return c - a / np.log(x) 99 | 100 | 101 | all_models["ilog2"] = ilog2 102 | model_defaults["ilog2"] = {"a": 0.43, "c": 0.78} 103 | display_name_mapping["ilog2"] = "ilog$_2$" 104 | 105 | 106 | def dr_hill_zero_background(x, theta, eta, kappa): 107 | x_eta = x ** eta 108 | return (theta * x_eta) / (kappa ** eta + x_eta) 109 | 110 | 111 | all_models["dr_hill_zero_background"] = dr_hill_zero_background 112 | model_defaults["dr_hill_zero_background"] = {"theta": 0.772320, "eta": 0.586449, 113 | "kappa": 2.460843} 114 | display_name_mapping["dr_hill_zero_background"] = "Hill$_3$" 115 | 116 | 117 | def logx_linear(x, a, b): 118 | x = np.log(x) 119 | return a * x + b 120 | 121 | 122 | all_models["logx_linear"] = logx_linear 123 | model_defaults["logx_linear"] = {"a": 0.378106, "b": 0.046506} 124 | display_name_mapping["logx_linear"] = "log x linear" 125 | 126 | 127 | def vap(x, a, b, c): 128 | """ Vapor pressure model """ 129 | return np.exp(a + b / x + c * np.log(x)) 130 | 131 | 132 | all_models["vap"] = vap 133 | model_defaults["vap"] = {"a": -0.622028, "c": 0.042322, "b": -0.470050} 134 | display_name_mapping["vap"] = "vapor pressure" 135 | 136 | 137 | def loglog_linear(x, a, b): 138 | x = np.log(x) 139 | return np.log(a * x + b) 140 | 141 | 142 | all_models["loglog_linear"] = loglog_linear 143 | display_name_mapping["loglog_linear"] = "log log linear" 144 | 145 | 146 | # Models that we chose not to use in the ensembles/model combinations: 147 | 148 | # source: http://aclweb.org/anthology//P/P12/P12-1003.pdf 149 | def exp3(x, c, a, b): 150 | return c - np.exp(-a * x + b) 151 | 152 | 153 | all_models["exp3"] = exp3 154 | model_defaults["exp3"] = {"c": 
0.7, "a": 0.01, "b": -1} 155 | display_name_mapping["exp3"] = "exp$_3$" 156 | 157 | 158 | def exp4(x, c, a, b, alpha): 159 | return c - np.exp(-a * (x ** alpha) + b) 160 | 161 | 162 | all_models["exp4"] = exp4 163 | model_defaults["exp4"] = {"c": 0.7, "a": 0.8, "b": -0.8, "alpha": 0.3} 164 | display_name_mapping["exp4"] = "exp$_4$" 165 | 166 | 167 | # not bounded! 168 | # def logy_linear(x, a, b): 169 | # return np.log(a*x + b) 170 | # all_models["logy_linear"] = logy_linear 171 | 172 | def pow2(x, a, alpha): 173 | return a * x ** (-alpha) 174 | 175 | 176 | all_models["pow2"] = pow2 177 | model_defaults["pow2"] = {"a": 0.1, "alpha": -0.3} 178 | display_name_mapping["pow2"] = "pow$_2$" 179 | 180 | 181 | def pow4(x, c, a, b, alpha): 182 | return c - (a * x + b) ** -alpha 183 | 184 | 185 | all_models["pow4"] = pow4 186 | model_defaults["pow4"] = {"alpha": 0.1, "a": 200, "b": 0., "c": 0.8} 187 | display_name_mapping["pow4"] = "pow$_4$" 188 | 189 | 190 | def sat_growth(x, a, b): 191 | return a * x / (b + x) 192 | 193 | 194 | all_models["sat_growth"] = sat_growth 195 | model_defaults["sat_growth"] = {"a": 0.7, "b": 20} 196 | display_name_mapping["sat_growth"] = "saturated growth rate" 197 | 198 | 199 | def dr_hill(x, alpha, theta, eta, kappa): 200 | return alpha + (theta * (x ** eta)) / (kappa ** eta + x ** eta) 201 | 202 | 203 | all_models["dr_hill"] = dr_hill 204 | model_defaults["dr_hill"] = {"alpha": 0.1, "theta": 0.772320, "eta": 0.586449, 205 | "kappa": 2.460843} 206 | display_name_mapping["dr_hill"] = "Hill$_4$" 207 | 208 | 209 | def gompertz(x, a, b, c): 210 | """ 211 | Gompertz growth function. 212 | 213 | sigmoidal family 214 | a is the upper asymptote, since 215 | b, c are negative numbers 216 | b sets the displacement along the x axis (translates the graph to the left or right) 217 | c sets the growth rate (y scaling) 218 | 219 | e.g. used to model the growth of tumors 220 | 221 | http://en.wikipedia.org/wiki/Gompertz_function 222 | """ 223 | return a * np.exp(-b * np.exp(-c * x)) 224 | # return a + b * np.exp(np.exp(-k*(x-i))) 225 | 226 | 227 | all_models["gompertz"] = gompertz 228 | model_defaults["gompertz"] = {"a": 0.8, "b": 1000, "c": 0.05} 229 | display_name_mapping["gompertz"] = "Gompertz" 230 | 231 | 232 | def logistic_curve(x, a, k, b): 233 | """ 234 | a: asymptote 235 | k: 236 | b: inflection point 237 | http://www.pisces-conservation.com/growthhelp/logistic_curve.htm 238 | """ 239 | return a / (1. + np.exp(-k * (x - b))) 240 | 241 | 242 | all_models["logistic_curve"] = logistic_curve 243 | model_defaults["logistic_curve"] = {"a": 0.8, "k": 0.01, "b": 1.} 244 | display_name_mapping["logistic_curve"] = "logistic curve" 245 | 246 | 247 | def bertalanffy(x, a, k): 248 | """ 249 | a: asymptote 250 | k: growth rate 251 | http://www.pisces-conservation.com/growthhelp/von_bertalanffy.htm 252 | """ 253 | return a * (1. 
- np.exp(-k * x)) 254 | 255 | 256 | all_models["bertalanffy"] = bertalanffy 257 | model_defaults["bertalanffy"] = {"a": 0.8, "k": 0.01} 258 | display_name_mapping["bertalanffy"] = "Bertalanffy" 259 | 260 | curve_combination_models_old = ["vap", "ilog2", "weibull", "pow3", "pow4", 261 | "loglog_linear", 262 | "mmf", "janoschek", "dr_hill_zero_background", 263 | "log_power", 264 | "exp4"] 265 | 266 | curve_combination_models = ["weibull", "pow4", "mmf", "pow3", "loglog_linear", 267 | "janoschek", "dr_hill_zero_background", "log_power", 268 | "exp4"] 269 | 270 | curve_ensemble_models = ["vap", "ilog2", "weibull", "pow3", "pow4", 271 | "loglog_linear", 272 | "mmf", "janoschek", "dr_hill_zero_background", 273 | "log_power", 274 | "exp4"] 275 | -------------------------------------------------------------------------------- /pybnn/lc_extrapolation/curvemodels.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import traceback 4 | 5 | import numpy as np 6 | from scipy.optimize import fmin_l_bfgs_b, leastsq 7 | from scipy.stats import norm 8 | 9 | 10 | def recency_weights(num): 11 | if num == 1: 12 | return np.ones(1) 13 | else: 14 | recency_weights = [10 ** (1. / num)] * num 15 | recency_weights = recency_weights ** (np.arange(0, num)) 16 | return recency_weights 17 | 18 | 19 | class CurveModel(object): 20 | def __init__(self, 21 | function, 22 | function_der=None, 23 | min_vals={}, 24 | max_vals={}, 25 | default_vals={}): 26 | """ 27 | function: the function to be fit 28 | function_der: derivative of that function 29 | """ 30 | self.function = function 31 | if function_der != None: 32 | raise NotImplementedError( 33 | "function derivate is not implemented yet...sorry!") 34 | self.function_der = function_der 35 | assert isinstance(min_vals, dict) 36 | self.min_vals = min_vals.copy() 37 | assert isinstance(max_vals, dict) 38 | self.max_vals = max_vals.copy() 39 | function_args = inspect.getargspec(function).args 40 | assert "x" in function_args, "The function needs 'x' as a parameter." 41 | for default_param_name in default_vals.keys(): 42 | if default_param_name == "sigma": 43 | continue 44 | msg = "function %s doesn't take default param %s" % ( 45 | function.__name__, default_param_name) 46 | assert default_param_name in function_args, msg 47 | self.function_params = [param for param in function_args if 48 | param != 'x'] 49 | # set default values: 50 | self.default_vals = default_vals.copy() 51 | for param_name in self.function_params: 52 | if param_name not in default_vals: 53 | logging.info("setting function parameter %s to default of 1.0 for " 54 | "function %s" % (param_name, self.function.__name__)) 55 | self.default_vals[param_name] = 1.0 56 | self.all_param_names = [param for param in self.function_params] 57 | self.all_param_names.append("sigma") 58 | self.name = self.function.__name__ 59 | self.ndim = len(self.all_param_names) 60 | 61 | # uniform noise prior over interval: 62 | if "sigma" not in self.min_vals: 63 | self.min_vals["sigma"] = 0. 64 | if "sigma" not in self.max_vals: 65 | self.max_vals["sigma"] = 1.0 66 | if "sigma" not in self.default_vals: 67 | self.default_vals["sigma"] = 0.05 68 | 69 | def default_function_param_array(self): 70 | return np.asarray([self.default_vals[param_name] for param_name in 71 | self.function_params]) 72 | 73 | def are_params_in_bounds(self, theta): 74 | """ 75 | Are the parameters in their respective bounds? 
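        For example, with min_vals={"alpha": 0.} a theta whose "alpha" entry is
        negative is reported as out of bounds; parameters that have no entry in
        min_vals or max_vals are never rejected.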
76 | """ 77 | in_bounds = True 78 | 79 | for param_name, param_value in zip(self.all_param_names, theta): 80 | if param_name in self.min_vals: 81 | if param_value < self.min_vals[param_name]: 82 | in_bounds = False 83 | if param_name in self.max_vals: 84 | if param_value > self.max_vals[param_name]: 85 | in_bounds = False 86 | return in_bounds 87 | 88 | def split_theta(self, theta): 89 | """Split theta into the function parameters (dict) and sigma. """ 90 | params = {} 91 | sigma = None 92 | for param_name, param_value in zip(self.all_param_names, theta): 93 | if param_name in self.function_params: 94 | params[param_name] = param_value 95 | elif param_name == "sigma": 96 | sigma = param_value 97 | return params, sigma 98 | 99 | def split_theta_to_array(self, theta): 100 | """Split theta into the function parameters (array) and sigma. """ 101 | params = theta[:-1] 102 | sigma = theta[-1] 103 | return params, sigma 104 | 105 | def fit(self, x, y): 106 | raise NotImplementedError() 107 | 108 | def predict(self, x): 109 | raise NotImplementedError() 110 | 111 | def predict_given_theta(self, x, theta): 112 | """ 113 | Make predictions given a single theta 114 | """ 115 | params, sigma = self.split_theta(theta) 116 | predictive_mu = self.function(x, **params) 117 | return predictive_mu, sigma 118 | 119 | def likelihood(self, x, y): 120 | """ 121 | for each y_i in y: 122 | p(y_i|x, model) 123 | """ 124 | params, sigma = self.split_theta(self.ml_params) 125 | return norm.pdf(y - self.function(x, **params), loc=0, scale=sigma) 126 | 127 | 128 | class MLCurveModel(CurveModel): 129 | """ 130 | ML fit of a curve. 131 | """ 132 | 133 | def __init__(self, recency_weighting=True, **kwargs): 134 | super(MLCurveModel, self).__init__(**kwargs) 135 | 136 | # Maximum Likelihood values of the parameters 137 | self.ml_params = None 138 | self.recency_weighting = recency_weighting 139 | 140 | def fit(self, x, y, weights=None, start_from_default=True): 141 | """ 142 | weights: None or weight for each sample. 143 | """ 144 | return self.fit_ml(x, y, weights, start_from_default) 145 | 146 | def predict(self, x): 147 | # assert len(x.shape) == 1 148 | params, sigma = self.split_theta_to_array(self.ml_params) 149 | return self.function(x, *params) 150 | # return np.asarray([self.function(x_pred, **params) for x_pred in x]) 151 | 152 | def fit_ml(self, x, y, weights, start_from_default): 153 | """ 154 | non-linear least-squares fit of the data. 155 | 156 | First tries Levenberg-Marquardt and falls back 157 | to BFGS in case that fails. 158 | 159 | Start from default values or from previous ml_params? 160 | """ 161 | successful = self.fit_leastsq(x, y, weights, start_from_default) 162 | if not successful: 163 | successful = self.fit_bfgs(x, y, weights, start_from_default) 164 | if not successful: 165 | return False 166 | return successful 167 | 168 | def ml_sigma(self, x, y, popt, weights): 169 | """ 170 | Given the ML parameters (popt) get the ML estimate of sigma. 
171 | """ 172 | if weights is None: 173 | if self.recency_weighting: 174 | variance = np.average((y - self.function(x, *popt)) ** 2, 175 | weights=recency_weights(len(y))) 176 | sigma = np.sqrt(variance) 177 | else: 178 | sigma = (y - self.function(x, *popt)).std() 179 | else: 180 | if self.recency_weighting: 181 | variance = np.average((y - self.function(x, *popt)) ** 2, 182 | weights=recency_weights(len(y)) * weights) 183 | sigma = np.sqrt(variance) 184 | else: 185 | variance = np.average((y - self.function(x, *popt)) ** 2, 186 | weights=weights) 187 | sigma = np.sqrt(variance) 188 | return sigma 189 | 190 | def fit_leastsq(self, x, y, weights, start_from_default): 191 | try: 192 | if weights is None: 193 | if self.recency_weighting: 194 | residuals = lambda p: np.sqrt(recency_weights(len(y))) * ( 195 | self.function(x, *p) - y) 196 | else: 197 | residuals = lambda p: self.function(x, *p) - y 198 | else: 199 | # the return value of this function will be squared, hence 200 | # we need to take the sqrt of the weights here 201 | if self.recency_weighting: 202 | residuals = lambda p: np.sqrt( 203 | recency_weights(len(y)) * weights) * ( 204 | self.function(x, *p) - y) 205 | else: 206 | residuals = lambda p: np.sqrt(weights) * ( 207 | self.function(x, *p) - y) 208 | 209 | if start_from_default: 210 | initial_params = self.default_function_param_array() 211 | else: 212 | initial_params, _ = self.split_theta_to_array(self.ml_params) 213 | 214 | popt, cov_popt, info, msg, status = leastsq(residuals, 215 | x0=initial_params, 216 | full_output=True) 217 | # Dfun=, 218 | # col_deriv=True) 219 | 220 | if np.any(np.isnan(info["fjac"])): 221 | return False 222 | 223 | leastsq_success_statuses = [1, 2, 3, 4] 224 | if status in leastsq_success_statuses: 225 | if any(np.isnan(popt)): 226 | return False 227 | # within bounds? 
228 | if not self.are_params_in_bounds(popt): 229 | return False 230 | 231 | sigma = self.ml_sigma(x, y, popt, weights) 232 | self.ml_params = np.append(popt, [sigma]) 233 | 234 | logging.info( 235 | "leastsq successful for model %s" % self.function.__name__) 236 | 237 | return True 238 | else: 239 | logging.warn("leastsq NOT successful for model %s, msg: %s" % ( 240 | self.function.__name__, msg)) 241 | logging.warn("best parameters found: " + str(popt)) 242 | return False 243 | except Exception as e: 244 | logging.error(e) 245 | tb = traceback.format_exc() 246 | logging.error(tb) 247 | return False 248 | 249 | def fit_bfgs(self, x, y, weights, start_from_default): 250 | try: 251 | def objective(params): 252 | if weights is None: 253 | if self.recency_weighting: 254 | return np.sum(recency_weights(len(y)) * ( 255 | self.function(x, *params) - y) ** 2) 256 | else: 257 | return np.sum((self.function(x, *params) - y) ** 2) 258 | else: 259 | if self.recency_weighting: 260 | return np.sum(weights * recency_weights(len(y)) * ( 261 | self.function(x, *params) - y) ** 2) 262 | else: 263 | return np.sum( 264 | weights * (self.function(x, *params) - y) ** 2) 265 | 266 | bounds = [] 267 | for param_name in self.function_params: 268 | if param_name in self.min_vals and param_name in self.max_vals: 269 | bounds.append( 270 | (self.min_vals[param_name], self.max_vals[param_name])) 271 | elif param_name in self.min_vals: 272 | bounds.append((self.min_vals[param_name], None)) 273 | elif param_name in self.max_vals: 274 | bounds.append((None, self.max_vals[param_name])) 275 | else: 276 | bounds.append((None, None)) 277 | 278 | if start_from_default: 279 | initial_params = self.default_function_param_array() 280 | else: 281 | initial_params, _ = self.split_theta_to_array(self.ml_params) 282 | 283 | popt, fval, info = fmin_l_bfgs_b(objective, 284 | x0=initial_params, 285 | bounds=bounds, 286 | approx_grad=True) 287 | if info["warnflag"] != 0: 288 | logging.warn( 289 | "BFGS not converged! (warnflag %d) for model %s" % ( 290 | info["warnflag"], self.name)) 291 | logging.warn(info) 292 | return False 293 | 294 | if popt is None: 295 | return False 296 | if any(np.isnan(popt)): 297 | logging.info( 298 | "bfgs NOT successful for model %s, parameter NaN" % self.name) 299 | return False 300 | sigma = self.ml_sigma(x, y, popt, weights) 301 | self.ml_params = np.append(popt, [sigma]) 302 | logging.info("bfgs successful for model %s" % self.name) 303 | return True 304 | except: 305 | return False 306 | 307 | def aic(self, x, y): 308 | """ 309 | Akaike information criterion 310 | http://en.wikipedia.org/wiki/Akaike_information_criterion 311 | """ 312 | params, sigma = self.split_theta_to_array(self.ml_params) 313 | y_model = self.function(x, *params) 314 | log_likelihood = norm.logpdf(y - y_model, loc=0, scale=sigma).sum() 315 | return 2 * len(self.function_params) - 2 * log_likelihood 316 | -------------------------------------------------------------------------------- /pybnn/lc_extrapolation/learning_curves.py: -------------------------------------------------------------------------------- 1 | import emcee 2 | import logging 3 | import numpy as np 4 | from pybnn.lc_extrapolation.curvefunctions import curve_combination_models, \ 5 | model_defaults, all_models 6 | from pybnn.lc_extrapolation.curvemodels import MLCurveModel 7 | from scipy.optimize import nnls 8 | from scipy.stats import norm 9 | 10 | 11 | def recency_weights(num): 12 | if num == 1: 13 | return np.ones(1) 14 | else: 15 | recency_weights = [10 ** (1. 
/ num)] * num 16 | recency_weights = recency_weights ** (np.arange(0, num)) 17 | return recency_weights 18 | 19 | 20 | def model_ln_prob(theta, model, x, y): 21 | return model.ln_prob(theta, x, y) 22 | 23 | 24 | class MCMCCurveModelCombination(object): 25 | def __init__(self, 26 | xlim, 27 | ml_curve_models=None, 28 | burn_in=500, 29 | nwalkers=100, 30 | nsamples=2500, 31 | normalize_weights=True, 32 | monotonicity_constraint=True, 33 | soft_monotonicity_constraint=False, 34 | initial_model_weight_ml_estimate=False, 35 | normalized_weights_initialization="constant", 36 | strictly_positive_weights=True, 37 | sanity_check_prior=True, 38 | nthreads=1, 39 | recency_weighting=True): 40 | """ 41 | xlim: the point on the x axis we eventually want to make predictions for. 42 | """ 43 | if ml_curve_models is None: 44 | curve_models = [] 45 | for model_name in curve_combination_models: 46 | if model_name in model_defaults: 47 | m = MLCurveModel(function=all_models[model_name], 48 | default_vals=model_defaults[model_name], 49 | recency_weighting=False) 50 | else: 51 | m = MLCurveModel(function=all_models[model_name], 52 | recency_weighting=False) 53 | curve_models.append(m) 54 | self.ml_curve_models = curve_models 55 | else: 56 | self.ml_curve_models = ml_curve_models 57 | 58 | self.xlim = xlim 59 | self.burn_in = burn_in 60 | self.nwalkers = nwalkers 61 | self.nsamples = nsamples 62 | self.normalize_weights = normalize_weights 63 | assert not ( 64 | monotonicity_constraint and soft_monotonicity_constraint), "choose either the monotonicity_constraint or the soft_monotonicity_constraint, but not both" 65 | self.monotonicity_constraint = monotonicity_constraint 66 | self.soft_monotonicity_constraint = soft_monotonicity_constraint 67 | self.initial_model_weight_ml_estimate = initial_model_weight_ml_estimate 68 | self.normalized_weights_initialization = normalized_weights_initialization 69 | self.strictly_positive_weights = strictly_positive_weights 70 | self.sanity_check_prior = sanity_check_prior 71 | self.nthreads = nthreads 72 | self.recency_weighting = recency_weighting 73 | # the constant used for initializing the parameters in a ball around the ML parameters 74 | self.rand_init_ball = 1e-6 75 | self.name = "model combination" # (%s)" % ", ".join([model.name for model in self.ml_curve_models]) 76 | 77 | if self.monotonicity_constraint: 78 | self._x_mon = np.linspace(2, self.xlim, 50) 79 | else: 80 | self._x_mon = np.asarray([2, self.xlim]) 81 | 82 | # TODO check that burnin is lower than nsamples 83 | 84 | def fit(self, x, y, model_weights=None): 85 | if self.fit_ml_individual(x, y, model_weights): 86 | # run MCMC: 87 | logging.info('Fitted models!') 88 | self.fit_mcmc(x, y) 89 | logging.info('Fitted mcmc!') 90 | return True 91 | else: 92 | logging.warning("fit_ml_individual failed") 93 | return False 94 | 95 | def y_lim_sanity_check(self, ylim): 96 | # just make sure that the prediction is not below 0 nor insanely big 97 | # HOWEVER: there might be cases where some models might predict value larger than 1.0 98 | # and this is alright, because in those cases we don't necessarily want to stop a run. 99 | assert not isinstance(ylim, np.ndarray) 100 | if not np.isfinite(ylim) or ylim < 0. 
or ylim > 100.0: 101 | return False 102 | else: 103 | return True 104 | 105 | def y_lim_sanity_check_array(self, ylim): 106 | # just make sure that the prediction is not below 0 nor insanely big 107 | # HOWEVER: there might be cases where some models might predict value larger than 1.0 108 | # and this is alright, because in those cases we don't necessarily want to stop a run. 109 | assert isinstance(ylim, np.ndarray) 110 | return ~(~np.isfinite(ylim) | (ylim < 0.) | (ylim > 100.0)) 111 | 112 | def fit_ml_individual(self, x, y, model_weights): 113 | """ 114 | Do a ML fit for each model individually and then another ML fit for the combination of models. 115 | """ 116 | self.fit_models = [] 117 | for model in self.ml_curve_models: 118 | if model.fit(x, y): 119 | ylim = model.predict(self.xlim) 120 | if not self.y_lim_sanity_check(ylim): 121 | print("ML fit of model %s is out of bound range [0.0, " 122 | "100.] at xlim." % (model.function.__name__)) 123 | continue 124 | params, sigma = model.split_theta_to_array(model.ml_params) 125 | if not np.isfinite(self._ln_model_prior(model, np.array([params]))[0]): 126 | print("ML fit of model %s is not supported by prior." % 127 | model.function.__name__) 128 | continue 129 | self.fit_models.append(model) 130 | 131 | if len(self.fit_models) == 0: 132 | return False 133 | 134 | if model_weights is None: 135 | if self.normalize_weights: 136 | if self.normalized_weights_initialization == "constant": 137 | # initialize with a constant value 138 | # we will sample in this unnormalized space and then later normalize 139 | model_weights = [10. for model in self.fit_models] 140 | else: # self.normalized_weights_initialization == "normalized" 141 | model_weights = [1. / len(self.fit_models) for model in 142 | self.fit_models] 143 | else: 144 | if self.initial_model_weight_ml_estimate: 145 | model_weights = self.get_ml_model_weights(x, y) 146 | non_zero_fit_models = [] 147 | non_zero_weights = [] 148 | for w, model in zip(model_weights, self.fit_models): 149 | if w > 1e-4: 150 | non_zero_fit_models.append(model) 151 | non_zero_weights.append(w) 152 | self.fit_models = non_zero_fit_models 153 | model_weights = non_zero_weights 154 | else: 155 | model_weights = [1. / len(self.fit_models) for model in 156 | self.fit_models] 157 | 158 | # build joint ml estimated parameter vector 159 | model_params = [] 160 | all_model_params = [] 161 | for model in self.fit_models: 162 | params, sigma = model.split_theta_to_array(model.ml_params) 163 | model_params.append(params) 164 | all_model_params.extend(params) 165 | 166 | y_predicted = self._predict_given_params( 167 | x, [np.array([mp]) for mp in model_params], 168 | np.array([model_weights])) 169 | sigma = (y - y_predicted).std() 170 | 171 | self.ml_params = self._join_theta(all_model_params, sigma, model_weights) 172 | self.ndim = len(self.ml_params) 173 | if self.nwalkers < 2 * self.ndim: 174 | self.nwalkers = 2 * self.ndim 175 | logging.warning("increasing number of walkers to 2*ndim=%d" % ( 176 | self.nwalkers)) 177 | return True 178 | 179 | def get_ml_model_weights(self, x, y_target): 180 | """ 181 | Get the ML estimate of the model weights. 182 | """ 183 | 184 | """ 185 | Take all the models that have been fit using ML. 186 | For each model we get a prediction of y: y_i 187 | 188 | Now how can we combine those to reduce the squared error: 189 | 190 | argmin_w (y_target - w_1 * y_1 - w_2 * y_2 - w_3 * y_3 ...) 191 | 192 | Deriving and setting to zero we get a linear system of equations that we need to solve. 
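        Concretely, with A[i, j] = y_i.dot(y_j) and b[i] = y_i.dot(y_target)
        (exactly the arrays a and b built below), the weights satisfy the normal
        equations A w = b; the nnls call returns the non-negative w that
        minimizes ||A w - b||.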
193 | 194 | 195 | Resource on QP: 196 | http://stats.stackexchange.com/questions/21565/how-do-i-fit-a-constrained-regression-in-r-so-that-coefficients-total-1 197 | http://maggotroot.blogspot.de/2013/11/constrained-linear-least-squares-in.html 198 | """ 199 | num_models = len(self.fit_models) 200 | y_predicted = [] 201 | b = [] 202 | for model in self.fit_models: 203 | y_model = model.predict(x) 204 | y_predicted.append(y_model) 205 | b.append(y_model.dot(y_target)) 206 | a = np.zeros((num_models, num_models)) 207 | for i in range(num_models): 208 | for j in range(num_models): 209 | a[i, j] = y_predicted[i].dot(y_predicted[j]) 210 | # if i == j: 211 | # a[i, j] -= 0.1 #constraint the weights! 212 | a_rank = np.linalg.matrix_rank(a) 213 | if a_rank != num_models: 214 | print("Rank %d not sufficcient for solving the linear system. %d " 215 | "needed at least." % (a_rank, num_models)) 216 | try: 217 | print(np.linalg.lstsq(a, b)[0]) 218 | print(np.linalg.solve(a, b)) 219 | print(nnls(a, b)[0]) 220 | ##return np.linalg.solve(a, b) 221 | weights = nnls(a, b)[0] 222 | # weights = [w if w > 1e-4 else 1e-4 for w in weights] 223 | return weights 224 | # except LinAlgError as e: 225 | except: 226 | return [1. / len(self.fit_models) for model in self.fit_models] 227 | 228 | # priors 229 | def _ln_prior(self, theta): 230 | # TODO remove this check, accept only 2d data 231 | if len(theta.shape) == 1: 232 | theta = theta.reshape((1, -1)) 233 | 234 | ln = np.array([0.] * len(theta)) 235 | model_params, sigma, model_weights = self._split_theta(theta) 236 | 237 | # we expect all weights to be positive 238 | # TODO add unit test for this! 239 | 240 | if self.strictly_positive_weights: 241 | violation = np.any(model_weights < 0, axis=1) 242 | ln[violation] = -np.inf 243 | 244 | for model, params in zip(self.fit_models, model_params): 245 | # Only calculate the prior further when the value is still finite 246 | mask = np.isfinite(ln) 247 | if np.sum(mask) == 0: 248 | break 249 | ln[mask] += self._ln_model_prior(model, params[mask]) 250 | 251 | # if self.normalize_weights: 252 | # when we normalize we expect all weights to be positive 253 | return ln 254 | 255 | def _ln_model_prior(self, model, params): 256 | prior = np.array([0.0] * len(params)) 257 | # reshaped_params = [ 258 | # np.array([params[j][i] 259 | # for j in range(len(params))]).reshape((-1, 1)) 260 | # for i in range(len(params[0]))] 261 | reshaped_params = [params[:, i].reshape((-1, 1)) 262 | for i in range(len(params[0]))] 263 | 264 | # prior_stats = [] 265 | # prior_stats.append((0, np.mean(~np.isfinite(prior)))) 266 | 267 | # TODO curvefunctions must be vectorized, too 268 | # y_mon = np.array([model.function(self._x_mon, *params_) 269 | # for params_ in params]) 270 | 271 | # Check, is this predict the most expensive part of the whole code? 
TODO 272 | # y_mon = model.function(self._x_mon, *reshaped_params) 273 | 274 | if self.monotonicity_constraint: 275 | y_mon = model.function(self._x_mon, *reshaped_params) 276 | # check for monotonicity(this obviously this is a hack, but it works for now): 277 | constraint_violated = np.any(np.diff(y_mon, axis=1) < 0, axis=1) 278 | prior[constraint_violated] = -np.inf 279 | # for i in range(len(y_mon)): 280 | # if np.any(np.diff(y_mon[i]) < 0): 281 | # prior[i] = -np.inf 282 | 283 | elif self.soft_monotonicity_constraint: 284 | y_mon = model.function(self._x_mon[[0, -1]], *reshaped_params) 285 | # soft monotonicity: defined as the last value being bigger than the first one 286 | not_monotone = [y_mon[i, 0] > y_mon[i, -1] for i in range(len(y_mon))] 287 | if any(not_monotone): 288 | for i, nm in enumerate(not_monotone): 289 | if nm: 290 | prior[i] = -np.inf 291 | 292 | else: 293 | y_mon = model.function(self._x_mon, *reshaped_params) 294 | 295 | # TODO curvefunctions must be vectorized, too 296 | # ylim = np.array([model.function(self.xlim, *params_) 297 | # for params_ in params]) 298 | # ylim = model.function(self.xlim, *reshaped_params) 299 | ylim = y_mon[:, -1] 300 | 301 | # sanity check for ylim 302 | if self.sanity_check_prior: 303 | sane = self.y_lim_sanity_check_array(ylim) 304 | prior[~sane.flatten()] = -np.inf 305 | # for i, s in enumerate(sane): 306 | # if not s: 307 | # prior[i] = -np.inf 308 | 309 | # TODO vectorize this! 310 | mask = np.isfinite(prior) 311 | for i, params_ in enumerate(params): 312 | # Only check parameters which are not yet rejected 313 | if mask[i] and not model.are_params_in_bounds(params_): 314 | prior[i] = -np.inf 315 | 316 | # prior_stats.append((3, np.mean(~np.isfinite(prior)))) 317 | # print(prior_stats) 318 | return prior 319 | 320 | # likelihood 321 | def _ln_likelihood(self, theta, x, y): 322 | y_model, sigma = self._predict_given_theta(x, theta) 323 | n_models = len(y_model) 324 | 325 | if self.recency_weighting: 326 | raise NotImplementedError() 327 | weight = recency_weights(len(y)) 328 | ln_likelihood = ( 329 | weight * norm.logpdf(y - y_model, loc=0, scale=sigma)).sum() 330 | else: 331 | # ln_likelihood = [norm.logpdf(y - y_model_, loc=0, scale=sigma_).sum() 332 | # for y_model_, sigma_ in zip(y_model, sigma)] 333 | # ln_likelihood = np.array(ln_likelihood) 334 | loc = np.zeros((n_models, 1)) 335 | sigma = sigma.reshape((-1, 1)) 336 | ln_likelihood2 = norm.logpdf(y - y_model, loc=loc, 337 | scale=sigma).sum(axis=1) 338 | # print(ln_likelihood == ln_likelihood2) 339 | ln_likelihood = ln_likelihood2 340 | 341 | ln_likelihood[~np.isfinite(ln_likelihood)] = -np.inf 342 | return ln_likelihood 343 | 344 | def _ln_prob(self, theta, x, y): 345 | """ 346 | posterior probability 347 | """ 348 | lp = self._ln_prior(theta) 349 | lp[~np.isfinite(lp)] = -np.inf 350 | ln_prob = lp + self._ln_likelihood(theta, x, y) 351 | return ln_prob 352 | 353 | def _split_theta(self, theta): 354 | """ 355 | theta is structured as follows: 356 | for each model i 357 | for each model parameter j 358 | theta = (theta_ij, sigma, w_i) 359 | """ 360 | # TODO remove this check, theta should always be 2d! 
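        # Illustrative layout (hypothetical example with two fitted models
        # f1(x, a, b) and f2(x, c)): a single theta row reads
        #     [a, b, c, sigma, w_1, w_2]
        # i.e. all per-model parameters first, then the shared noise sigma,
        # then one weight per fitted model -- the same ordering that the
        # slicing below and _join_theta assume.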
361 | if len(theta.shape) == 1: 362 | theta = theta.reshape((1, -1)) 363 | 364 | all_model_params = [] 365 | for model in self.fit_models: 366 | num_model_params = len(model.function_params) 367 | model_params = theta[:, :num_model_params] 368 | all_model_params.append(model_params) 369 | 370 | theta = theta[:, num_model_params:] 371 | 372 | sigma = theta[:, 0] 373 | model_weights = theta[:, 1:] 374 | assert model_weights.shape[1] == len(self.fit_models) 375 | return all_model_params, sigma, model_weights 376 | 377 | def _join_theta(self, model_params, sigma, model_weights): 378 | # assert len(model_params) == len(model_weights) 379 | theta = [] 380 | theta.extend(model_params) 381 | theta.append(sigma) 382 | theta.extend(model_weights) 383 | return theta 384 | 385 | def fit_mcmc(self, x, y): 386 | # initialize in an area around the starting position 387 | 388 | class PseudoPool(object): 389 | def map(self, func, proposals): 390 | return [f for f in func(np.array(proposals))] 391 | 392 | rstate0 = np.random.RandomState(1) 393 | assert self.ml_params is not None 394 | pos = [self.ml_params + self.rand_init_ball * rstate0.randn(self.ndim) 395 | for i in range(self.nwalkers)] 396 | 397 | if self.nthreads <= 1: 398 | sampler = emcee.EnsembleSampler(self.nwalkers, 399 | self.ndim, 400 | self._ln_prob, 401 | args=(x, y), 402 | pool=PseudoPool()) 403 | else: 404 | sampler = emcee.EnsembleSampler( 405 | self.nwalkers, 406 | self.ndim, 407 | model_ln_prob, 408 | args=(self, x, y), 409 | threads=self.nthreads) 410 | sampler.run_mcmc(pos, self.nsamples, rstate0=rstate0) 411 | self.mcmc_chain = sampler.chain 412 | 413 | if self.normalize_weights: 414 | self.normalize_chain_model_weights() 415 | 416 | def normalize_chain_model_weights(self): 417 | """ 418 | In the chain we sample w_1,... w_i however we are interested in the model 419 | probabilities p_1,... p_i 420 | """ 421 | model_weights_chain = self.mcmc_chain[:, :, -len(self.fit_models):] 422 | model_probabilities_chain = model_weights_chain / model_weights_chain.sum( 423 | axis=2)[:, :, np.newaxis] 424 | # replace in chain 425 | self.mcmc_chain[:, :, 426 | -len(self.fit_models):] = model_probabilities_chain 427 | 428 | def get_burned_in_samples(self): 429 | samples = self.mcmc_chain[:, self.burn_in:, :].reshape((-1, self.ndim)) 430 | return samples 431 | 432 | def print_probs(self): 433 | burned_in_chain = self.get_burned_in_samples() 434 | model_probabilities = burned_in_chain[:, -len(self.fit_models):] 435 | print(model_probabilities.mean(axis=0)) 436 | 437 | def _predict_given_theta(self, x, theta): 438 | """ 439 | returns y_predicted, sigma 440 | """ 441 | model_params, sigma, model_weights = self._split_theta(theta) 442 | 443 | y_predicted = self._predict_given_params(x, model_params, model_weights) 444 | 445 | return y_predicted, sigma 446 | 447 | def _predict_given_params(self, x, model_params, model_weights): 448 | """ 449 | returns y_predicted 450 | """ 451 | 452 | if self.normalize_weights: 453 | model_weight_sum = np.sum(model_weights, axis=1) 454 | model_ws = (model_weights.transpose() / model_weight_sum).transpose() 455 | else: 456 | model_ws = model_weights 457 | 458 | # # TODO vectorize! 
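        # Note: the active code further below is already vectorized; for every
        # posterior sample it computes sum_k w_k * f_k(x, params_k) over the
        # fitted models by broadcasting each model's parameter columns. The
        # commented-out loop that follows is the slower per-sample reference
        # implementation.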
459 | # vectorized_predictions = [] 460 | # for i in range(len(model_weights)): 461 | # y_model = [] 462 | # for model, model_w, params in zip(self.fit_models, model_ws[i], 463 | # model_params): 464 | # y_model.append(model_w * model.function(x, *params[i])) 465 | # y_predicted = functools.reduce(lambda a, b: a + b, y_model) 466 | # vectorized_predictions.append(y_predicted) 467 | 468 | len_x = len(x) if hasattr(x, '__len__') else 1 469 | test_predictions = np.zeros((len(model_weights), len_x)) 470 | for model, model_w, params in zip(self.fit_models, model_ws.transpose(), 471 | model_params): 472 | params2 = [params[:, i].reshape((-1, 1)) 473 | for i in range(params.shape[1])] 474 | params = params2 475 | # params = [np.array([params[j][i] for j in range(len(params))]).reshape((-1, 1)) 476 | # for i in range(len(params[0]))] 477 | # print('Diff', np.sum(np.array(params2) 478 | # - np.array(params).reshape((len(params2), -1)))) 479 | prediction = model_w.reshape((-1, 1)) * model.function(x, *params) 480 | test_predictions += prediction 481 | 482 | return test_predictions 483 | # return np.array(vectorized_predictions) 484 | 485 | def predictive_distribution(self, x, thin=1): 486 | assert isinstance(x, float) or isinstance(x, int), (x, type(x)) 487 | 488 | samples = self.get_burned_in_samples() 489 | predictions = [] 490 | for theta in samples[::thin]: 491 | model_params, sigma, model_weights = self._split_theta(theta) 492 | y_predicted = self._predict_given_params(x, model_params, 493 | model_weights) 494 | predictions.append(y_predicted) 495 | return np.asarray(predictions) 496 | 497 | def prob_x_greater_than(self, x, y, theta): 498 | """ 499 | P(f(x) > y | Data, theta) 500 | """ 501 | model_params, sigma, model_weights = self._split_theta(theta) 502 | 503 | y_predicted = self._predict_given_params(x, model_params, model_weights) 504 | 505 | cdf = norm.cdf(y, loc=y_predicted, scale=sigma) 506 | 507 | return 1. - cdf 508 | 509 | def posterior_prob_x_greater_than(self, x, y, thin=1): 510 | """ 511 | P(f(x) > y | Data) 512 | 513 | Posterior probability that f(x) is greater than y. 514 | """ 515 | assert isinstance(x, float) or isinstance(x, int) 516 | assert isinstance(y, float) or isinstance(y, int) 517 | probs = [] 518 | samples = self.get_burned_in_samples() 519 | for theta in samples[::thin]: 520 | probs.append(self.prob_x_greater_than(x, y, theta)) 521 | 522 | return np.ma.masked_invalid(probs).mean() 523 | -------------------------------------------------------------------------------- /pybnn/lcnet.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from pybnn.bohamiann import Bohamiann 8 | from pybnn.util.layers import AppendLayer 9 | 10 | 11 | def vapor_pressure(x, a, b, c, *args): 12 | b_ = (b + 1) / 2 / 10 13 | a_ = (a + 1) / 2 14 | c_ = (c + 1) / 2 / 10 15 | return torch.exp(-a_ - b_ / (x + 1e-5) - c_ * torch.log(x)) - (torch.exp(a_ + b_)) 16 | 17 | 18 | def log_func(t, a, b, c, *args): 19 | a_ = (a + 1) / 2 * 5 20 | b_ = (b + 1) / 2 21 | c_ = (c + 1) / 2 * 10 22 | return (c_ + a_ * torch.log(b_ * t + 1e-10)) / 10. 23 | 24 | 25 | def hill_3(x, a, b, c, *args): 26 | a_ = (a + 1) / 2 27 | b_ = (b + 1) / 2 28 | c_ = (c + 1) / 2 / 100 29 | return a_ * (1. 
/ ((c_ / x + 1e-5) ** b_ + 1.)) 30 | 31 | 32 | def bf_layer(theta, t): 33 | y_a = vapor_pressure(t, theta[:, 0], theta[:, 1], theta[:, 2]) 34 | 35 | y_b = log_func(t, theta[:, 3], theta[:, 4], theta[:, 5]) 36 | 37 | y_c = hill_3(t, theta[:, 6], theta[:, 7], theta[:, 8]) 38 | 39 | return torch.stack([y_a, y_b, y_c], dim=1) 40 | 41 | 42 | def get_lc_net_architecture(input_dimensionality: int) -> torch.nn.Module: 43 | class Architecture(nn.Module): 44 | def __init__(self, n_inputs, n_hidden=50): 45 | super(Architecture, self).__init__() 46 | self.fc1 = nn.Linear(n_inputs - 1, n_hidden) 47 | self.fc2 = nn.Linear(n_hidden, n_hidden) 48 | self.fc3 = nn.Linear(n_hidden, n_hidden) 49 | self.theta_layer = nn.Linear(n_hidden, 9) 50 | self.weight_layer = nn.Linear(n_hidden, 3) 51 | self.asymptotic_layer = nn.Linear(n_hidden, 1) 52 | self.sigma_layer = AppendLayer(noise=1e-3) 53 | 54 | def forward(self, input): 55 | x = input[:, :-1] 56 | t = input[:, -1] 57 | x = torch.tanh(self.fc1(x)) 58 | x = torch.tanh(self.fc2(x)) 59 | x = torch.tanh(self.fc3(x)) 60 | theta = torch.tanh(self.theta_layer(x)) 61 | 62 | bf = bf_layer(theta, t) 63 | weights = torch.softmax(self.weight_layer(x), -1) 64 | residual = torch.tanh(torch.sum(bf * weights, dim=(1,), keepdim=True)) 65 | 66 | asymptotic = torch.sigmoid(self.asymptotic_layer(x)) 67 | 68 | mean = residual + asymptotic 69 | return self.sigma_layer(mean) 70 | 71 | return Architecture(n_inputs=input_dimensionality) 72 | 73 | 74 | class LCNet(Bohamiann): 75 | def __init__(self, **kwargs) -> None: 76 | super(LCNet, self).__init__(get_network=get_lc_net_architecture, 77 | normalize_input=True, 78 | normalize_output=False, 79 | **kwargs) 80 | 81 | @staticmethod 82 | def normalize_input(x, m=None, s=None): 83 | if m is None: 84 | m = np.mean(x, axis=0) 85 | if s is None: 86 | s = np.std(x, axis=0) 87 | 88 | x_norm = deepcopy(x) 89 | x_norm[:, :-1] = (x[:, :-1] - m[:-1]) / s[:-1] 90 | 91 | return x_norm, m, s 92 | -------------------------------------------------------------------------------- /pybnn/multi_task_bohamiann.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from pybnn.bohamiann import Bohamiann 7 | from pybnn.util.layers import AppendLayer 8 | 9 | 10 | def get_multitask_network(input_dimensionality: int, n_tasks: int) -> torch.nn.Module: 11 | 12 | class Architecture(torch.nn.Module): 13 | def __init__(self, n_inputs, n_tasks, emb_dim=5, n_hidden=50): 14 | super(Architecture, self).__init__() 15 | self.fc1 = torch.nn.Linear(n_inputs - 1 + emb_dim, n_hidden) 16 | self.fc2 = torch.nn.Linear(n_hidden, n_hidden) 17 | self.fc3 = torch.nn.Linear(n_hidden, 1) 18 | self.log_std = AppendLayer(noise=1e-3) 19 | self.emb = torch.nn.Embedding(n_tasks, emb_dim) 20 | self.n_tasks = n_tasks 21 | 22 | def forward(self, input): 23 | x = input[:, :-1] 24 | t = input[:, -1:] 25 | t_emb = self.emb(t.long()[:, 0]) 26 | x = torch.cat((x, t_emb), dim=1) 27 | x = torch.tanh(self.fc1(x)) 28 | x = torch.tanh(self.fc2(x)) 29 | x = self.fc3(x) 30 | return self.log_std(x) 31 | 32 | return Architecture(n_inputs=input_dimensionality, n_tasks=n_tasks) 33 | 34 | 35 | class MultiTaskBohamiann(Bohamiann): 36 | def __init__(self, 37 | n_tasks: int, 38 | get_network=get_multitask_network, 39 | normalize_input: bool = True, 40 | normalize_output: bool = True, 41 | use_double_precision: bool = True, 42 | sampling_method: str = 
"adaptive_sghmc", 43 | metrics=(nn.MSELoss,) 44 | ) -> None: 45 | """ Bayesian Neural Network for regression problems. 46 | 47 | Bayesian Neural Networks use Bayesian methods to estimate the posterior 48 | distribution of a neural network's weights. This allows to also 49 | predict uncertainties for test points and thus makes Bayesian Neural 50 | Networks suitable for Bayesian optimization. 51 | This module uses stochastic gradient MCMC methods to sample 52 | from the posterior distribution. 53 | 54 | See [1] for more details. 55 | 56 | [1] J. T. Springenberg, A. Klein, S. Falkner, F. Hutter 57 | Bayesian Optimization with Robust Bayesian Neural Networks. 58 | In Advances in Neural Information Processing Systems 29 (2016). 59 | 60 | Parameters 61 | ---------- 62 | normalize_input: bool, optional 63 | Specifies if inputs should be normalized to zero mean and unit variance. 64 | normalize_output: bool, optional 65 | Specifies whether outputs should be un-normalized. 66 | """ 67 | self.n_tasks = n_tasks 68 | 69 | func = partial(get_network, n_tasks=n_tasks) 70 | super(MultiTaskBohamiann, self).__init__(func, normalize_output, normalize_input, 71 | sampling_method, metrics, use_double_precision) 72 | 73 | def normalize_input(self, x, m=None, s=None): 74 | if m is None: 75 | m = np.mean(x, axis=0) 76 | if s is None: 77 | s = np.std(x, axis=0) 78 | x_norm = deepcopy(x) 79 | x_norm[:, :-1] = (x[:, :-1] - m[:-1]) / s[:-1] 80 | 81 | return x_norm, m, s 82 | -------------------------------------------------------------------------------- /pybnn/priors.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def log_variance_prior(log_variance: torch.Tensor, mean: float = 1e-6, variance: float = 0.01) -> torch.Tensor: 7 | return torch.mean( 8 | torch.sum( 9 | ((-((log_variance - torch.log(torch.tensor(mean, dtype=log_variance.dtype))) ** 2)) / 10 | (2. * variance)) - 0.5 * torch.log(torch.tensor(variance, dtype=log_variance.dtype)), 11 | dim=1 12 | ) 13 | ) 14 | 15 | 16 | def weight_prior(parameters: Iterable[torch.Tensor], dtype=np.float64, wdecay: float = 1.) -> torch.Tensor: 17 | 18 | num_parameters = 0 19 | log_likelihood = torch.from_numpy(np.array(0, dtype=dtype)) 20 | for parameter in parameters: 21 | num_parameters += parameter.numel() 22 | log_likelihood += torch.sum(-wdecay * 0.5 * (parameter ** 2)) 23 | 24 | return log_likelihood / num_parameters 25 | -------------------------------------------------------------------------------- /pybnn/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptive_sghmc import AdaptiveSGHMC 2 | from .preconditioned_sgld import PreconditionedSGLD 3 | from .sgld import SGLD 4 | from .sghmc import SGHMC 5 | -------------------------------------------------------------------------------- /pybnn/sampler/adaptive_sghmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.optim import Optimizer 4 | 5 | 6 | class AdaptiveSGHMC(Optimizer): 7 | """ Stochastic Gradient Hamiltonian Monte-Carlo Sampler that uses a burn-in 8 | procedure to adapt its own hyperparameters during the initial stages 9 | of sampling. 10 | 11 | See [1] for more details on this burn-in procedure.\n 12 | See [2] for more details on Stochastic Gradient Hamiltonian Monte-Carlo. 13 | 14 | [1] J. T. Springenberg, A. Klein, S. Falkner, F. 
Hutter 15 | In Advances in Neural Information Processing Systems 29 (2016).\n 16 | `Bayesian Optimization with Robust Bayesian Neural Networks. `_ 17 | [2] T. Chen, E. B. Fox, C. Guestrin 18 | In Proceedings of Machine Learning Research 32 (2014).\n 19 | `Stochastic Gradient Hamiltonian Monte Carlo `_ 20 | """ 21 | 22 | def __init__(self, 23 | params, 24 | lr: float = 1e-2, 25 | num_burn_in_steps: int = 3000, 26 | epsilon: float = 1e-16, 27 | mdecay: float = 0.05, 28 | scale_grad: float = 1.) -> None: 29 | """ Set up a SGHMC Optimizer. 30 | 31 | Parameters 32 | ---------- 33 | params : iterable 34 | Parameters serving as optimization variable. 35 | lr: float, optional 36 | Base learning rate for this optimizer. 37 | Must be tuned to the specific function being minimized. 38 | Default: `1e-2`. 39 | num_burn_in_steps: int, optional 40 | Number of burn-in steps to perform. In each burn-in step, this 41 | sampler will adapt its own internal parameters to decrease its error. 42 | Set to `0` to turn scale adaption off. 43 | Default: `3000`. 44 | epsilon: float, optional 45 | (Constant) per-parameter epsilon level. 46 | Default: `0.`. 47 | mdecay:float, optional 48 | (Constant) momentum decay per time-step. 49 | Default: `0.05`. 50 | scale_grad: float, optional 51 | Value that is used to scale the magnitude of the epsilon used 52 | during sampling. In a typical batches-of-data setting this usually 53 | corresponds to the number of examples in the entire dataset. 54 | Default: `1.0`. 55 | 56 | """ 57 | if lr < 0.0: 58 | raise ValueError("Invalid learning rate: {}".format(lr)) 59 | if num_burn_in_steps < 0: 60 | raise ValueError("Invalid num_burn_in_steps: {}".format(num_burn_in_steps)) 61 | 62 | defaults = dict( 63 | lr=lr, scale_grad=float(scale_grad), 64 | num_burn_in_steps=num_burn_in_steps, 65 | mdecay=mdecay, 66 | epsilon=epsilon 67 | ) 68 | super().__init__(params, defaults) 69 | 70 | def step(self, closure=None): 71 | loss = None 72 | 73 | if closure is not None: 74 | loss = closure() 75 | 76 | for group in self.param_groups: 77 | for parameter in group["params"]: 78 | 79 | if parameter.grad is None: 80 | continue 81 | 82 | state = self.state[parameter] 83 | 84 | if len(state) == 0: 85 | state["iteration"] = 0 86 | state["tau"] = torch.ones_like(parameter) 87 | state["g"] = torch.ones_like(parameter) 88 | state["v_hat"] = torch.ones_like(parameter) 89 | state["momentum"] = torch.zeros_like(parameter) 90 | state["iteration"] += 1 91 | 92 | mdecay, epsilon, lr = group["mdecay"], group["epsilon"], group["lr"] 93 | scale_grad = torch.tensor(group["scale_grad"], dtype=parameter.dtype) 94 | tau, g, v_hat = state["tau"], state["g"], state["v_hat"] 95 | 96 | momentum = state["momentum"] 97 | gradient = parameter.grad.data * scale_grad 98 | 99 | tau_inv = 1. / (tau + 1.) 100 | 101 | # update parameters during burn-in 102 | if state["iteration"] <= group["num_burn_in_steps"]: 103 | tau.add_(- tau * ( 104 | g * g / (v_hat + epsilon)) + 1) # specifies the moving average window, see Eq 9 in [1] left 105 | g.add_(-g * tau_inv + tau_inv * gradient) # average gradient see Eq 9 in [1] right 106 | v_hat.add_(-v_hat * tau_inv + tau_inv * (gradient ** 2)) # gradient variance see Eq 8 in [1] 107 | 108 | minv_t = 1. / (torch.sqrt(v_hat) + epsilon) # preconditioner 109 | 110 | epsilon_var = (2. 
* (lr ** 2) * mdecay * minv_t - (lr ** 4)) 111 | 112 | # sample random epsilon 113 | sigma = torch.sqrt(torch.clamp(epsilon_var, min=1e-16)) 114 | sample_t = torch.normal(mean=torch.zeros_like(gradient), std=torch.ones_like(gradient) * sigma) 115 | 116 | # update momentum (Eq 10 right in [1]) 117 | momentum.add_( 118 | - (lr ** 2) * minv_t * gradient - mdecay * momentum + sample_t 119 | ) 120 | 121 | # update theta (Eq 10 left in [1]) 122 | parameter.data.add_(momentum) 123 | 124 | return loss 125 | -------------------------------------------------------------------------------- /pybnn/sampler/preconditioned_sgld.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim import Optimizer 4 | 5 | 6 | class PreconditionedSGLD(Optimizer): 7 | """ Stochastic Gradient Langevin Dynamics Sampler with preconditioning. 8 | Optimization variable is viewed as a posterior sample under Stochastic 9 | Gradient Langevin Dynamics with noise rescaled in each dimension 10 | according to RMSProp. 11 | """ 12 | def __init__(self, 13 | params, 14 | lr=np.float64(1e-2), 15 | num_train_points=1, 16 | precondition_decay_rate=np.float64(0.99), 17 | diagonal_bias=np.float64(1e-5)) -> None: 18 | """ Set up a SGLD Optimizer. 19 | 20 | Parameters 21 | ---------- 22 | params : iterable 23 | Parameters serving as optimization variable. 24 | lr : float, optional 25 | Base learning rate for this optimizer. 26 | Must be tuned to the specific function being minimized. 27 | Default: `1e-2`. 28 | precondition_decay_rate : float, optional 29 | Exponential decay rate of the rescaling of the preconditioner (RMSprop). 30 | Should be smaller than but nearly `1` to approximate sampling from the posterior. 31 | Default: `0.99` 32 | diagonal_bias : float, optional 33 | Term added to the diagonal of the preconditioner to prevent it from 34 | degenerating. 35 | Default: `1e-5`. 36 | 37 | """ 38 | if lr < 0.0: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | 41 | defaults = dict( 42 | lr=lr, precondition_decay_rate=precondition_decay_rate, 43 | diagonal_bias=diagonal_bias, 44 | num_train_points=num_train_points 45 | ) 46 | super().__init__(params, defaults) 47 | 48 | def step(self, closure=None): 49 | loss = None 50 | 51 | if closure is not None: 52 | loss = closure() 53 | 54 | for group in self.param_groups: 55 | for parameter in group["params"]: 56 | 57 | if parameter.grad is None: 58 | continue 59 | 60 | state = self.state[parameter] 61 | lr = group["lr"] 62 | num_train_points = group["num_train_points"] 63 | precondition_decay_rate = group["precondition_decay_rate"] # alpha 64 | diagonal_bias = group["diagonal_bias"] # lambda 65 | gradient = parameter.grad.data * num_train_points 66 | 67 | # state initialization 68 | if len(state) == 0: 69 | state["iteration"] = 0 70 | state["momentum"] = torch.ones_like(parameter) 71 | 72 | state["iteration"] += 1 73 | 74 | # momentum update 75 | momentum = state["momentum"] 76 | momentum_t = momentum * precondition_decay_rate + (1.0 - precondition_decay_rate) * (gradient ** 2) 77 | state["momentum"] = momentum_t # V(theta_t+1) 78 | 79 | # compute preconditioner 80 | preconditioner = (1. 
/ (torch.sqrt(momentum_t) + diagonal_bias)) # G(theta_t+1) 81 | 82 | # standard deviation of the injected noise 83 | sigma = torch.sqrt(torch.from_numpy(np.array(lr, dtype=type(lr)))) * torch.sqrt(preconditioner) 84 | 85 | mean = 0.5 * lr * (preconditioner * gradient) 86 | delta = (mean + torch.normal(mean=torch.zeros_like(gradient), std=torch.ones_like(gradient)) * sigma) 87 | 88 | parameter.data.add_(-delta) 89 | 90 | return loss 91 | -------------------------------------------------------------------------------- /pybnn/sampler/sghmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim import Optimizer 4 | 5 | 6 | class SGHMC(Optimizer): 7 | """ Stochastic Gradient Hamiltonian Monte-Carlo Sampler that uses a burn-in 8 | procedure to adapt its own hyperparameters during the initial stages 9 | of sampling. 10 | 11 | See [1] for more details on Stochastic Gradient Hamiltonian Monte-Carlo. 12 | 13 | [1] T. Chen, E. B. Fox, C. Guestrin 14 | In Proceedings of Machine Learning Research 32 (2014).\n 15 | `Stochastic Gradient Hamiltonian Monte Carlo `_ 16 | """ 17 | name = "AdaptiveSGHMC" 18 | 19 | def __init__(self, 20 | params, 21 | lr: float=1e-2, 22 | mdecay: float=0.01, 23 | wd: float=0.00002, 24 | scale_grad: float=1.) -> None: 25 | """ Set up a SGHMC Optimizer. 26 | 27 | Parameters 28 | ---------- 29 | params : iterable 30 | Parameters serving as optimization variable. 31 | lr: float, optional 32 | Base learning rate for this optimizer. 33 | Must be tuned to the specific function being minimized. 34 | Default: `1e-2`. 35 | mdecay:float, optional 36 | (Constant) momentum decay per time-step. 37 | Default: `0.05`. 38 | scale_grad: float, optional 39 | Value that is used to scale the magnitude of the noise used 40 | during sampling. In a typical batches-of-data setting this usually 41 | corresponds to the number of examples in the entire dataset. 42 | Default: `1.0`. 
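        A minimal usage sketch (illustrative only; model, x_batch, y_batch, a
        negative-log-posterior loss_fn and the training-set size n_train are
        assumed to exist):

            sampler = SGHMC(model.parameters(), lr=1e-2, scale_grad=float(n_train))
            sampler.zero_grad()
            loss = loss_fn(model(x_batch), y_batch)
            loss.backward()
            sampler.step()  # one stochastic-gradient HMC update of the parameters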
43 | 44 | """ 45 | if lr < 0.0: 46 | raise ValueError("Invalid learning rate: {}".format(lr)) 47 | 48 | defaults = dict( 49 | lr=lr, scale_grad=scale_grad, 50 | mdecay=mdecay, 51 | wd=wd 52 | ) 53 | super().__init__(params, defaults) 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | 58 | if closure is not None: 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | for parameter in group["params"]: 63 | 64 | if parameter.grad is None: 65 | continue 66 | 67 | state = self.state[parameter] 68 | 69 | if len(state) == 0: 70 | state["iteration"] = 0 71 | state["momentum"] = torch.randn(parameter.size(), dtype=parameter.dtype) 72 | 73 | state["iteration"] += 1 74 | 75 | mdecay, lr, wd = group["mdecay"], group["lr"], group["wd"] 76 | scale_grad = group["scale_grad"] 77 | 78 | momentum = state["momentum"] 79 | gradient = parameter.grad.data * scale_grad 80 | 81 | sigma = torch.sqrt(torch.from_numpy(np.array(2 * lr * mdecay, dtype=type(lr)))) 82 | sample_t = torch.normal(mean=torch.zeros_like(gradient), std=torch.ones_like(gradient) * sigma) 83 | 84 | parameter.data.add_(lr * mdecay * momentum) 85 | momentum.add_(-lr * gradient - mdecay * lr * momentum + sample_t) 86 | return loss 87 | -------------------------------------------------------------------------------- /pybnn/sampler/sgld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | 6 | # def decay(t, a, b , gamma): 7 | # return a * (b + t) ** (-gamma) 8 | # 9 | # def poly(base_lr, t, max_iter, power): 10 | # return base_lr * (1 - t / max_iter) ** power 11 | # 12 | # def const(*args): 13 | # return 1 14 | 15 | 16 | class SGLD(Optimizer): 17 | """ Stochastic Gradient Langevin Dynamics Sampler 18 | """ 19 | 20 | def __init__(self, 21 | params, 22 | lr: np.float64 = 1e-2, 23 | scale_grad: np.float64 = 1) -> None: 24 | 25 | """ Set up a SGLD Optimizer. 26 | 27 | Parameters 28 | ---------- 29 | params : iterable 30 | Parameters serving as optimization variable. 31 | lr : float, optional 32 | Base learning rate for this optimizer. 33 | Must be tuned to the specific function being minimized. 34 | Default: `1e-2`. 35 | """ 36 | if lr < 0.0: 37 | raise ValueError("Invalid learning rate: {}".format(lr)) 38 | 39 | # if lr_decay is None: 40 | # self.lr_decay = const 41 | # pass 42 | # elif lr_decay == "inv": 43 | # final_lr_fraction = 1e-2 44 | # degree = 2 45 | # gamma = (np.power(1 / final_lr_fraction, 1. 
/ degree) - 1) / (T - 1) 46 | # self.lr_decay = lambda t: lr * np.power((1 + gamma * t), -degree) 47 | # else: 48 | # self.lr_decay = lr_decay 49 | defaults = dict( 50 | lr=lr, 51 | scale_grad=scale_grad 52 | ) 53 | super().__init__(params, defaults) 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | 58 | if closure is not None: 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | for parameter in group["params"]: 63 | 64 | if parameter.grad is None: 65 | continue 66 | 67 | state = self.state[parameter] 68 | lr, scale_grad = group["lr"], group["scale_grad"] 69 | # the average gradient over the batch, i.e N/n sum_i g_theta_i + g_prior 70 | gradient = parameter.grad.data * scale_grad 71 | # State initialization 72 | if len(state) == 0: 73 | state["iteration"] = 0 74 | 75 | sigma = torch.sqrt(torch.from_numpy(np.array(lr, dtype=type(lr)))) 76 | delta = (0.5 * lr * gradient + 77 | sigma * torch.normal(mean=torch.zeros_like(gradient), std=torch.ones_like(gradient))) 78 | 79 | parameter.data.add_(-delta) 80 | state["iteration"] += 1 81 | state["sigma"] = sigma 82 | 83 | return loss 84 | -------------------------------------------------------------------------------- /pybnn/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/pybnn/59befe512d6f668f1dfb39ad9ba4c55abc8dd0f6/pybnn/util/__init__.py -------------------------------------------------------------------------------- /pybnn/util/infinite_dataloader.py: -------------------------------------------------------------------------------- 1 | def infinite_dataloader(dataloader): 2 | """ Yield an unbounded amount of batches from a `torch.utils.data.DataLoader`. 3 | Parameters 4 | ---------- 5 | dataloader : torch.utils.data.DataLoader 6 | Iterable yielding batches of data from a dataset of interest. 
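    A minimal usage sketch (illustrative; train_ds stands for any existing
    torch.utils.data.Dataset):

        loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
        batches = infinite_dataloader(loader)
        batch = next(batches)  # next() can be called indefinitely; the loader is cycled forever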
7 | """ 8 | while True: 9 | for batch in dataloader: 10 | yield batch -------------------------------------------------------------------------------- /pybnn/util/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AppendLayer(nn.Module): 7 | def __init__(self, noise=1e-3, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.log_var = nn.Parameter(torch.DoubleTensor(1, 1)) 10 | 11 | nn.init.constant_(self.log_var, val=np.log(noise)) 12 | 13 | def forward(self, x): 14 | return torch.cat((x, self.log_var * torch.ones_like(x)), dim=1) 15 | -------------------------------------------------------------------------------- /pybnn/util/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def zero_one_normalization(X, lower=None, upper=None): 5 | 6 | if lower is None: 7 | lower = np.min(X, axis=0) 8 | if upper is None: 9 | upper = np.max(X, axis=0) 10 | 11 | X_normalized = np.true_divide((X - lower), (upper - lower)) 12 | 13 | return X_normalized, lower, upper 14 | 15 | 16 | def zero_one_denormalization(X_normalized, lower, upper): 17 | return lower + (upper - lower) * X_normalized 18 | 19 | 20 | def zero_mean_unit_var_normalization(X, mean=None, std=None): 21 | if mean is None: 22 | mean = np.mean(X, axis=0) 23 | if std is None: 24 | std = np.std(X, axis=0) 25 | 26 | X_normalized = (X - mean) / std 27 | 28 | return X_normalized, mean, std 29 | 30 | 31 | def zero_mean_unit_var_denormalization(X_normalized, mean, std): 32 | return X_normalized * std + mean 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | emcee 5 | scipy 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='pybnn', 5 | version='0.0.5', 6 | description='Simple python framework for Bayesian neural networks', 7 | author='Aaron Klein, Moritz Freidank', 8 | author_email='kleinaa@cs.uni-freiburg.de', 9 | url="https://github.com/automl/pybnn", 10 | license='BSD 3-Clause License', 11 | classifiers=['Development Status :: 4 - Beta'], 12 | packages=find_packages(), 13 | python_requires='>=3', 14 | install_requires=['torch', 'torchvision', 'numpy', 'emcee', 'scipy'], 15 | extras_require={}, 16 | keywords=['python', 'Bayesian', 'neural networks'], 17 | ) 18 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/pybnn/59befe512d6f668f1dfb39ad9ba4c55abc8dd0f6/test/__init__.py -------------------------------------------------------------------------------- /test/test_bohamiann.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from scipy.optimize import check_grad 5 | 6 | from pybnn.bohamiann import Bohamiann 7 | 8 | 9 | class TestBohamiann(unittest.TestCase): 10 | 11 | def setUp(self): 12 | self.X = np.random.rand(10, 3) 13 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 14 | self.model = Bohamiann(normalize_input=True, 
normalize_output=True, use_double_precision=True) 15 | self.model.train(self.X, self.y, num_burn_in_steps=20, num_steps=100, keep_every=10) 16 | 17 | def test_predict(self): 18 | X_test = np.random.rand(10, self.X.shape[1]) 19 | 20 | m, v = self.model.predict(X_test) 21 | 22 | assert len(m.shape) == 1 23 | assert m.shape[0] == X_test.shape[0] 24 | assert len(v.shape) == 1 25 | assert v.shape[0] == X_test.shape[0] 26 | 27 | def test_gradient_mean(self): 28 | X_test = np.random.rand(10, self.X.shape[1]) 29 | 30 | def wrapper(x): 31 | return self.model.predict([x])[0] 32 | 33 | def wrapper_grad(x): 34 | return self.model.predictive_mean_gradient(x) 35 | 36 | grad = self.model.predictive_mean_gradient(X_test[0]) 37 | assert grad.shape[0] == X_test.shape[1] 38 | 39 | for xi in X_test: 40 | err = check_grad(wrapper, wrapper_grad, xi, epsilon=1e-6) 41 | assert err < 1e-5 42 | 43 | def test_gradient_variance(self): 44 | X_test = np.random.rand(10, self.X.shape[1]) 45 | 46 | def wrapper(x): 47 | v = self.model.predict([x])[1] 48 | return v 49 | 50 | def wrapper_grad(x): 51 | return self.model.predictive_variance_gradient(x) 52 | 53 | grad = self.model.predictive_variance_gradient(X_test[0]) 54 | assert grad.shape[0] == X_test.shape[1] 55 | 56 | for xi in X_test: 57 | err = check_grad(wrapper, wrapper_grad, xi, epsilon=1e-6) 58 | assert err < 1e-5 59 | 60 | 61 | class TestBohamiannSampler(unittest.TestCase): 62 | 63 | def test_sgld(self): 64 | self.X = np.random.rand(10, 3) 65 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 66 | self.model = Bohamiann(normalize_input=True, normalize_output=True, 67 | use_double_precision=True, sampling_method="sgld") 68 | self.model.train(self.X, self.y, num_burn_in_steps=20, num_steps=100, keep_every=10) 69 | 70 | def test_preconditioned_sgld(self): 71 | self.X = np.random.rand(10, 3) 72 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 73 | self.model = Bohamiann(normalize_input=True, normalize_output=True, 74 | use_double_precision=True, sampling_method="preconditioned_sgld") 75 | self.model.train(self.X, self.y, num_burn_in_steps=20, num_steps=100, keep_every=10) 76 | 77 | def test_sghmc(self): 78 | self.X = np.random.rand(10, 3) 79 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 80 | self.model = Bohamiann(normalize_input=True, normalize_output=True, 81 | use_double_precision=True, sampling_method="sghmc") 82 | self.model.train(self.X, self.y, num_burn_in_steps=20, num_steps=100, keep_every=10) 83 | 84 | def test_adaptive_sghmc(self): 85 | self.X = np.random.rand(10, 3) 86 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 87 | self.model = Bohamiann(normalize_input=True, normalize_output=True, 88 | use_double_precision=True, sampling_method="adaptive_sghmc") 89 | self.model.train(self.X, self.y, num_burn_in_steps=20, num_steps=100, keep_every=10) 90 | 91 | 92 | if __name__ == "__main__": 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /test/test_dngo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from pybnn.dngo import DNGO 5 | 6 | 7 | class TestDNGO(unittest.TestCase): 8 | 9 | def setUp(self): 10 | self.X = np.random.rand(10, 3) 11 | self.y = np.sinc(self.X * 10 - 5).sum(axis=1) 12 | 13 | def test_mcmc(self): 14 | model = DNGO(num_epochs=10, burnin_steps=10, chain_length=20, do_mcmc=True) 15 | model.train(self.X, self.y) 16 | 17 | X_test = np.random.rand(10, self.X.shape[1]) 18 | 19 | m, v = model.predict(X_test) 20 | 
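# predict should return 1-d mean and variance arrays with one entry per test point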
21 | assert len(m.shape) == 1 22 | assert m.shape[0] == X_test.shape[0] 23 | assert len(v.shape) == 1 24 | assert v.shape[0] == X_test.shape[0] 25 | 26 | def test_ml(self): 27 | model = DNGO(num_epochs=10, do_mcmc=False) 28 | model.train(self.X, self.y) 29 | 30 | X_test = np.random.rand(10, self.X.shape[1]) 31 | 32 | m, v = model.predict(X_test) 33 | 34 | assert len(m.shape) == 1 35 | assert m.shape[0] == X_test.shape[0] 36 | assert len(v.shape) == 1 37 | assert v.shape[0] == X_test.shape[0] 38 | 39 | def test_without_normalization(self): 40 | model = DNGO(num_epochs=10, do_mcmc=False, normalize_output=False, normalize_input=False) 41 | model.train(self.X, self.y) 42 | 43 | X_test = np.random.rand(10, self.X.shape[1]) 44 | 45 | m, v = model.predict(X_test) 46 | 47 | assert len(m.shape) == 1 48 | assert m.shape[0] == X_test.shape[0] 49 | assert len(v.shape) == 1 50 | assert v.shape[0] == X_test.shape[0] 51 | 52 | def test_incumbent(self): 53 | model = DNGO(num_epochs=10, do_mcmc=False) 54 | model.train(self.X, self.y) 55 | 56 | x_star, y_star = model.get_incumbent() 57 | 58 | b = np.argmin(self.y) 59 | 60 | assert np.all(np.isclose(x_star, self.X[b])) 61 | assert np.all(np.isclose(y_star, self.y[b])) 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /test/test_lcnet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from pybnn.lcnet import LCNet 6 | 7 | 8 | class TestLCNet(unittest.TestCase): 9 | 10 | def test_train_predict(self): 11 | 12 | def toy_example(t, a, b): 13 | return (10 + a * np.log(b * t)) / 10. + 10e-3 * np.random.rand() 14 | 15 | observed = 20 16 | N = 5 17 | n_epochs = 10 18 | observed_t = int(n_epochs * (observed / 100.)) 19 | 20 | t_idx = np.arange(1, observed_t + 1) / n_epochs 21 | t_grid = np.arange(1, n_epochs + 1) / n_epochs 22 | 23 | configs = np.random.rand(N, 2) 24 | learning_curves = [toy_example(t_grid, configs[i, 0], configs[i, 1]) for i in range(N)] 25 | 26 | X_train = None 27 | y_train = None 28 | X_test = None 29 | y_test = None 30 | 31 | for i in range(N): 32 | 33 | x = np.repeat(configs[i, None, :], t_idx.shape[0], axis=0) 34 | x = np.concatenate((x, t_idx[:, None]), axis=1) 35 | 36 | x_test = np.concatenate((configs[i, None, :], np.array([[1]])), axis=1) 37 | 38 | lc = learning_curves[i][:observed_t] 39 | lc_test = np.array([learning_curves[i][-1]]) 40 | 41 | if X_train is None: 42 | X_train = x 43 | y_train = lc 44 | X_test = x_test 45 | y_test = lc_test 46 | else: 47 | X_train = np.concatenate((X_train, x), 0) 48 | y_train = np.concatenate((y_train, lc), 0) 49 | X_test = np.concatenate((X_test, x_test), 0) 50 | y_test = np.concatenate((y_test, lc_test), 0) 51 | 52 | print(X_train.shape) 53 | model = LCNet() 54 | 55 | model.train(X_train, y_train, num_steps=500, num_burn_in_steps=40, lr=1e-2) 56 | 57 | m, v = model.predict(X_test) 58 | 59 | assert len(m.shape) == 1 60 | assert m.shape[0] == X_test.shape[0] 61 | assert len(v.shape) == 1 62 | assert v.shape[0] == X_test.shape[0] 63 | 64 | 65 | if __name__ == "__main__": 66 | unittest.main() 67 | -------------------------------------------------------------------------------- /test/test_mtbohamiann.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from pybnn.multi_task_bohamiann import MultiTaskBohamiann 6 | 7 | 8 | class 
TestMTBohamiann(unittest.TestCase): 9 | 10 | def test_train_predict(self): 11 | 12 | def objective(x, task): 13 | if task == 0: 14 | y = 0.5 * np.sin(3 * x[0]) * 4 * (x[0] - 1) * (x[0]) 15 | elif task == 1: 16 | y = 0.5 * np.sin(3 * x[0] + 1) * 4 * (x[0] - 1) * (x[0]) 17 | elif task == 2: 18 | y = 0.5 * np.sin(3 * x[0] + 2) * 4 * (x[0] - 1) * (x[0]) 19 | return y 20 | 21 | upper = np.ones(1) * 6 22 | 23 | X = np.random.rand(30, 1) * upper 24 | y_t0 = np.array([objective(xi, 0) for xi in X[:10]]) 25 | y_t1 = np.array([objective(xi, 1) for xi in X[10:20]]) 26 | y_t2 = np.array([objective(xi, 2) for xi in X[20:]]) 27 | y = np.hstack((y_t0, y_t1, y_t2)) 28 | 29 | t_idx = np.zeros([30]) 30 | t_idx[10:20] = 1 31 | t_idx[20:] = 2 32 | 33 | X = np.append(X, t_idx[:, None], axis=1) 34 | 35 | model = MultiTaskBohamiann(n_tasks=3) 36 | 37 | model.train(X, y, num_steps=500, num_burn_in_steps=40, lr=1e-2) 38 | 39 | X_test = np.random.rand(5, 1) * upper 40 | X_test = np.append(X_test, np.ones([X_test.shape[0], 1]), axis=1) 41 | m, v = model.predict(X_test) 42 | 43 | assert len(m.shape) == 1 44 | assert m.shape[0] == X_test.shape[0] 45 | assert len(v.shape) == 1 46 | assert v.shape[0] == X_test.shape[0] 47 | 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /test/test_normalization.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from pybnn.util import normalization 5 | 6 | 7 | class TestNormalization(unittest.TestCase): 8 | 9 | def test_zero_one_normalization(self): 10 | 11 | X = np.random.randn(100, 3) 12 | X_norm, lo, up = normalization.zero_one_normalization(X) 13 | 14 | assert X_norm.shape == X.shape 15 | assert np.min(X_norm) >= 0 16 | assert np.max(X_norm) <= 1 17 | assert lo.shape[0] == X.shape[1] 18 | assert up.shape[0] == X.shape[1] 19 | 20 | def test_zero_one_unnormalization(self): 21 | X_norm = np.random.rand(100, 3) 22 | lo = np.ones([3]) * -1 23 | up = np.ones([3]) 24 | X = normalization.zero_one_denormalization(X_norm, lo, up) 25 | 26 | assert X_norm.shape == X.shape 27 | assert np.all(np.min(X, axis=0) >= lo) 28 | assert np.all(np.max(X, axis=0) <= up) 29 | 30 | def test_zero_mean_unit_var_normalization(self): 31 | X = np.random.rand(100, 3) 32 | X_norm, m, s = normalization.zero_mean_unit_var_normalization(X) 33 | 34 | np.testing.assert_almost_equal(np.mean(X_norm, axis=0), np.zeros(X_norm.shape[1]), decimal=1) 35 | np.testing.assert_almost_equal(np.var(X_norm, axis=0), np.ones(X_norm.shape[1]), decimal=1) 36 | 37 | assert X_norm.shape == X.shape 38 | assert m.shape[0] == X.shape[1] 39 | assert s.shape[0] == X.shape[1] 40 | 41 | def test_zero_one_unit_var_unnormalization(self): 42 | X_norm = np.random.randn(100, 3) 43 | m = np.ones(X_norm.shape[1]) * 3 44 | s = np.ones(X_norm.shape[1]) * 0.1 45 | X = normalization.zero_mean_unit_var_denormalization(X_norm, m, s) 46 | 47 | np.testing.assert_almost_equal(np.mean(X, axis=0), m, decimal=1) 48 | np.testing.assert_almost_equal(np.var(X, axis=0), s**2, decimal=1) 49 | 50 | assert X_norm.shape == X.shape 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | --------------------------------------------------------------------------------
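
For reference, below is a minimal usage sketch assembled from the calls exercised in the tests above. The toy objective, sample sizes, and step counts are illustrative placeholders (the tiny step counts simply mirror the tests; real runs would typically use far more burn-in and sampling steps), not recommended settings.

import numpy as np
from pybnn.bohamiann import Bohamiann

# Toy regression data in the same style as test_bohamiann.py (illustrative, not prescriptive).
rng = np.random.RandomState(0)
X = rng.rand(20, 3)
y = np.sinc(X * 10 - 5).sum(axis=1)

# The same constructor options the tests exercise: normalize inputs and outputs, sample in double precision.
model = Bohamiann(normalize_input=True, normalize_output=True, use_double_precision=True)

# Step counts mirror the tests and are far smaller than one would use in practice.
model.train(X, y, num_burn_in_steps=20, num_steps=100, keep_every=10)

# predict returns two 1-d arrays: predictive mean and variance, one entry per test point.
X_test = rng.rand(5, 3)
mean, var = model.predict(X_test)
print(mean.shape, var.shape)  # (5,), (5,)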