├── .gitignore ├── README.md ├── dvi ├── __init__.py ├── bayes_layers.py ├── bayes_models.py ├── bayes_utils.py ├── dataset.py ├── invgamma.py ├── kl.py ├── loss.py ├── plot.py └── variables.py ├── main.ipynb ├── main.py └── test ├── InvGamma_validation.ipynb └── check.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # 
mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deterministic-variational-inference-pytorch 2 | 3 | Re-implementation of [Deterministic-Variational-Inference](https://github.com/Microsoft/deterministic-variational-inference) 4 | -------------------------------------------------------------------------------- /dvi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/makora9143/deterministic-variational-inference-pytorch/6a467593382acad4ee241fb58ddfb6775f3a8efd/dvi/__init__.py -------------------------------------------------------------------------------- /dvi/bayes_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import dvi.bayes_utils as bu 5 | from .variables import make_weight_matrix, make_bias_vector, GaussianVar 6 | 7 | 8 | class VariationalLinear(nn.Module): 9 | def __init__(self, input_features, output_features, 10 | prior_type='empirical', 11 | variance='wider_he', bias=True): 12 | super(VariationalLinear, self).__init__() 13 | self.input_features = input_features 14 | self.output_features = output_features 15 | 16 | self.weight = make_weight_matrix((output_features, input_features), prior_type, variance) 17 | if bias: 18 | self.bias = make_bias_vector((output_features, input_features), prior_type, variance) 19 | else: 20 | self.register_parameter("bias", None) 21 | 22 | def forward(self, input): 23 | x_mean = input.mean 24 | y_mean = x_mean.mm(self.weight.q_loc.t()) 25 | if self.bias: 26 | y_mean += self.bias.q_loc.unsqueeze(0).expand_as(y_mean) 27 | x_cov = input.var 28 | y_cov = self.forward_covariance(x_mean, x_cov) 29 | return GaussianVar(y_mean, y_cov) 30 | 31 | def 
surprise(self): 32 | kl = torch.sum(self.weight.surprise()) 33 | if self.bias: 34 | kl += torch.sum(self.bias.surprise()) 35 | return kl 36 | 37 | def forward_covariance(self, x_mean, x_cov): 38 | output_dim, input_dim = self.weight.q_loc.shape 39 | 40 | x_var_diag = torch.diagonal(x_cov, dim1=-2, dim2=-1) 41 | xx_mean = x_var_diag + x_mean * x_mean 42 | 43 | term1_diag = xx_mean.mm(torch.pow(torch.exp(self.weight.log_q_scale), 2).t()) 44 | 45 | flat_xCov = x_cov.reshape(-1, input_dim) 46 | xCov_W = flat_xCov.mm(self.weight.q_loc.t()) 47 | xCov_W = xCov_W.reshape(-1, input_dim, output_dim) 48 | xCov_W = xCov_W.transpose(1, 2) 49 | xCov_W = xCov_W.reshape(-1, input_dim) 50 | W_xCov_W = xCov_W.mm(self.weight.q_loc.t()) 51 | W_xCov_W = W_xCov_W.reshape(-1, output_dim, output_dim) 52 | 53 | term2 = W_xCov_W 54 | term2_diag = torch.diagonal(term2, dim1=-2, dim2=-1) 55 | 56 | term3_diag = torch.pow(torch.exp(self.bias.log_q_scale), 2).unsqueeze(0).expand_as(term2_diag) 57 | 58 | result_diag = term1_diag + term2_diag + term3_diag 59 | return bu.matrix_set_diag(term2, result_diag, dim1=-2, dim2=-1) 60 | 61 | def forward_mcmc(self, input, n_samples=None, average=False): 62 | if n_samples is None: 63 | n_samples = 1 64 | 65 | repeated_x = input.unsqueeze(0).repeat(n_samples, 1, 1) 66 | 67 | sampled_w = self.weight.sample(n_samples, average) 68 | 69 | h = torch.matmul(repeated_x, sampled_w.transpose(1, 2)) 70 | 71 | if self.bias: 72 | sampled_b = self.bias.sample(n_samples, average) 73 | h += sampled_b.unsqueeze(1).expand_as(h) 74 | return h 75 | 76 | 77 | class VariationalLinearCertainActivations(VariationalLinear): 78 | def forward(self, input): 79 | x_mean = input 80 | xx = x_mean * x_mean 81 | y_mean = x_mean.mm(self.weight.q_loc.t()) 82 | if self.bias: 83 | y_mean += self.bias.q_loc.unsqueeze(0).expand_as(y_mean) 84 | 85 | y_cov = xx.mm(torch.pow(torch.exp(self.weight.log_q_scale), 2).t()) 86 | if self.bias: 87 | y_cov += torch.pow(torch.exp(self.bias.log_q_scale), 
2).unsqueeze(0).expand_as(y_cov) 88 | y_cov = torch.diag_embed(y_cov) 89 | return GaussianVar(y_mean, y_cov) 90 | 91 | 92 | class VariationalLinearReLU(VariationalLinear): 93 | def forward(self, input): 94 | x_var_diag = torch.diagonal(input.var, dim1=-2, dim2=-1) 95 | sqrt_x_var_diag = torch.sqrt(x_var_diag) 96 | mu = input.mean / (sqrt_x_var_diag + bu.EPSILON) 97 | 98 | def relu_covariance(x): 99 | mu1 = mu.unsqueeze(2) 100 | mu2 = mu1.transpose(1, 2) 101 | 102 | s11s22 = x_var_diag.unsqueeze(2) * x_var_diag.unsqueeze(1) 103 | rho = x.var / torch.sqrt(s11s22) 104 | rho = rho.clamp(-1 / (1 + bu.EPSILON), 1 / (1 + bu.EPSILON)) 105 | 106 | return x.var * bu.delta(rho, mu1, mu2) 107 | 108 | z_mean = sqrt_x_var_diag * bu.softrelu(mu) 109 | y_mean = z_mean.mm(self.weight.q_loc.t()) 110 | if self.bias: 111 | y_mean += self.bias.q_loc.unsqueeze(0).expand_as(y_mean) 112 | z_cov = relu_covariance(input) 113 | y_cov = self.forward_covariance(z_mean, z_cov) 114 | return GaussianVar(y_mean, y_cov) 115 | 116 | def forward_mcmc(self, input, n_samples=None, average=False): 117 | if n_samples is None: 118 | n_samples = 1 119 | 120 | sampled_w = self.weight.sample(n_samples, average) 121 | 122 | h = torch.matmul(input, sampled_w.transpose(1, 2)) 123 | 124 | if self.bias: 125 | sampled_b = self.bias.sample(n_samples, average) 126 | h += sampled_b.unsqueeze(1).expand_as(h) 127 | return h 128 | -------------------------------------------------------------------------------- /dvi/bayes_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .bayes_layers import VariationalLinearCertainActivations, VariationalLinearReLU 6 | from .variables import GaussianVar 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, x_dim, y_dim, hidden_size=None): 10 | super(MLP, self).__init__() 11 | 12 | self.sizes = [x_dim] 13 | if hidden_size is not None: 14 | self.sizes += 
hidden_size 15 | self.sizes += [y_dim] 16 | self.make_layers() 17 | 18 | 19 | def make_layers(self): 20 | # layers = [VariationalLinearCertainActivations(self.sizes[0], self.sizes[1])] 21 | # for in_dim, out_dim in zip(self.sizes[1:-1], self.sizes[2:]): 22 | # print('in_dim:{}, out_dim:{}'.format(in_dim, out_dim)) 23 | # layers.append(VariationalLinearReLU(in_dim, out_dim)) 24 | # self.layers = nn.Sequential(*layers) 25 | 26 | self.layers = nn.Sequential( 27 | VariationalLinearCertainActivations(1, 128), 28 | VariationalLinearReLU(128, 128), 29 | VariationalLinearReLU(128, 2) 30 | ) 31 | # 32 | # self.layers = nn.Sequential(VariationalLinearCertainActivations(self.sizes[0], self.sizes[1])) 33 | # for in_dim, out_dim in zip(self.sizes[1:-1], self.sizes[2:]): 34 | # print('in_dim:{}, out_dim:{}'.format(in_dim, out_dim)) 35 | # self.layers.add_module('{}-{}'.format(in_dim, out_dim), VariationalLinearReLU(in_dim, out_dim)) 36 | 37 | def forward(self, input): 38 | return self.layers(input) 39 | 40 | def surprise(self): 41 | all_surprise = 0 42 | for layer in self.layers: 43 | all_surprise += layer.surprise() 44 | return all_surprise 45 | 46 | def forward_mcmc(self, input, n_samples=None, average=False): 47 | h = self.layers[0].forward_mcmc(input) 48 | for layer in self.layers[1:]: 49 | h = layer.forward_mcmc(F.relu(h), n_samples) 50 | return h 51 | 52 | 53 | 54 | class AdaptedMLP(object): 55 | def __init__(self, mlp, adapter, device=torch.device('cpu')): 56 | self.mlp = mlp.to(device) 57 | self.__dict__.update(mlp.__dict__) 58 | self.device = device 59 | self.make_adapters(adapter) 60 | 61 | 62 | def make_adapters(self, adapter): 63 | self.adapter = {} 64 | for ad in ['in', 'out']: 65 | self.adapter[ad] = { 66 | 'scale': torch.tensor(adapter[ad]['scale']).to(self.device), 67 | 'shift': torch.tensor(adapter[ad]['shift']).to(self.device) 68 | } 69 | 70 | def __call__(self, input): 71 | x_ad = self.adapter['in']['scale'] * input + self.adapter['in']['shift'] 72 | 
self.pre_adapt = self.mlp(x_ad) 73 | mean = self.adapter['out']['scale'] * self.pre_adapt.mean + self.adapter['out']['shift'] 74 | cov = self.adapter['out']['scale'].reshape(-1, 1) * self.adapter['out']['scale'].reshape(1, -1) * self.pre_adapt.var 75 | return GaussianVar(mean, cov) 76 | 77 | def __repr__(self): 78 | return "AdaptedMLP(\n" + self.mlp.__repr__() + ")" 79 | 80 | def surprise(self): 81 | return self.mlp.surprise() 82 | 83 | def parameters(self): 84 | return self.mlp.parameters() 85 | 86 | def mcmc(self, input, n_samples=None): 87 | x_ad = self.adapter['in']['scale'] * input + self.adapter['in']['shift'] 88 | self.pre_adapt = self.mlp.forward_mcmc(x_ad, n_samples) 89 | mean = self.adapter['out']['scale'] * self.pre_adapt + self.adapter['out']['shift'] 90 | return mean 91 | -------------------------------------------------------------------------------- /dvi/bayes_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | EPSILON = 1e-6 5 | HALF_EPSILON = EPSILON / 2.0 6 | 7 | def matrix_set_diag(input, diagonal, dim1=0, dim2=1): 8 | org_diag = torch.diag_embed(torch.diagonal(input, dim1=dim1, dim2=dim2), 9 | dim1=dim1, dim2=dim2) 10 | new_diag = torch.diag_embed(diagonal, dim1=dim1, dim2=dim2) 11 | return input - org_diag + new_diag 12 | 13 | def gaussian_cdf(x): 14 | return 0.5 * (1.0 + torch.erf(x * 1 / math.sqrt(2.0))) 15 | 16 | def g(rho, mu1, mu2): 17 | one_plus_sqrt_one_minus_rho_sqr = 1.0 + torch.sqrt(1.0 - rho * rho) 18 | a = torch.asin(rho) - rho / one_plus_sqrt_one_minus_rho_sqr 19 | safe_a = torch.abs(a) + HALF_EPSILON 20 | safe_rho = torch.abs(rho) + EPSILON 21 | 22 | A = a / (2.0 * math.pi) 23 | sxx = safe_a * one_plus_sqrt_one_minus_rho_sqr / safe_rho 24 | one_ovr_sxy = (torch.asin(rho) - rho) / (safe_a * safe_rho) 25 | 26 | return A * torch.exp(-(mu1 * mu1 + mu2 * mu2) / (2.0 * sxx) + one_ovr_sxy * mu1 * mu2) 27 | 28 | 29 | def delta(rho, mu1, mu2): 30 | return 
gaussian_cdf(mu1) * gaussian_cdf(mu2) + g(rho, mu1, mu2) 31 | 32 | 33 | def standard_gaussian(x): 34 | return 1.0 / math.sqrt(2.0 * math.pi) * torch.exp(- x * x / 2.0) 35 | 36 | def softrelu(x): 37 | return standard_gaussian(x) + x * gaussian_cdf(x) 38 | 39 | def anneal(epoch, warmup=14000, anneal=1000): 40 | return 1.0 * max(min((epoch - warmup) / anneal, 1.0), 0.0) 41 | 42 | -------------------------------------------------------------------------------- /dvi/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class ToyDataset(Dataset): 8 | # data_size = {'train': 500, 'valid': 100, 'test': 100} 9 | def __init__(self, model=None, noise=None, 10 | data_size=500, sampling=True, transform=None): 11 | if model is None: 12 | model = self.base_model 13 | self.model = model 14 | 15 | if noise is None: 16 | noise = self.noise_model 17 | self.noise = noise 18 | 19 | self.data_size = data_size 20 | 21 | self.sampling = sampling 22 | 23 | self.transform = transform 24 | 25 | self.create_data() 26 | 27 | def __len__(self): 28 | return self.data_size 29 | 30 | def __getitem__(self, idx): 31 | x, y = self.data[idx] 32 | 33 | if self.transform: 34 | x = self.transform(x) 35 | 36 | return x, y 37 | 38 | def __repr__(self): 39 | return "Toy Dataset" 40 | 41 | def base_model(self, x): 42 | return - (x + 0.5) * np.sin(3 * np.pi * x) 43 | 44 | def noise_model(self, x): 45 | return 0.45 * (x + 0.5) ** 2 46 | 47 | def sample_data(self, x): 48 | return self.model(x) + np.random.normal(0, self.noise(x)) 49 | 50 | def create_data(self): 51 | if self.sampling: 52 | xs = np.random.rand(self.data_size, 1) - 0.5 53 | else: 54 | xs = np.arange(-1, 1, 1 / 100) 55 | ys = self.sample_data(xs) 56 | 57 | self.data = [(torch.from_numpy(x).float(), torch.tensor(torch.from_numpy(y).item()).float()) 58 | for x, y in zip(xs, ys)] 59 | 60 | 
-------------------------------------------------------------------------------- /dvi/invgamma.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | 3 | import torch 4 | from torch.distributions import constraints 5 | from torch.distributions.exp_family import ExponentialFamily 6 | from torch.distributions import MultivariateNormal, Laplace 7 | from torch.distributions.utils import broadcast_all 8 | 9 | 10 | def _standard_gamma(concentration): 11 | return torch._standard_gamma(concentration) 12 | 13 | 14 | class InverseGamma(ExponentialFamily): 15 | arg_constraints = {'concentration': constraints.positive, 16 | 'scale': constraints.positive, 17 | 'rate': constraints.positive} 18 | support = constraints.positive 19 | has_rsample = True 20 | _mean_carrier_measure = 0 21 | 22 | @property 23 | def mean(self): 24 | return self.scale / (self.concentration - 1) 25 | 26 | @property 27 | def variance(self): 28 | return torch.pow(self.scale, 2) / torch.pow(self.concentration - 1, 2) / (self.constraints - 2) 29 | 30 | def __init__(self, concentration, scale=None, rate=None, validate_args=None): 31 | if rate is not None: 32 | scale = rate 33 | 34 | self.concentration, self.scale = broadcast_all(concentration, scale) 35 | if isinstance(concentration, Number) and isinstance(scale, Number): 36 | batch_shape = torch.Size() 37 | else: 38 | batch_shape = self.concentration.size() 39 | super(InverseGamma, self).__init__(batch_shape, validate_args=validate_args) 40 | 41 | def expand(self, batch_shape, _instance=None): 42 | new = self._get_checked_instance(InverseGamma, _instance) 43 | batch_shape = torch.Size(batch_shape) 44 | new.concentration = self.concentration.expand(batch_shape) 45 | new.scale = self.scale.expand(batch_shape) 46 | super(InverseGamma, new).__init__(batch_shape, validate_args=False) 47 | new._validate_args = self._validate_args 48 | return new 49 | 50 | def rsample(self, sample_shape=torch.Size()): 
51 | shape = self._extended_shape(sample_shape) 52 | value = 1 / _standard_gamma(self.concentration.expand(shape)) * self.scale.expand(shape) 53 | value.detach().clamp_(min=torch.finfo(value.dtype).tiny) 54 | return value 55 | 56 | def log_prob(self, value): 57 | if self._validate_args: 58 | self._validate_sample(value) 59 | return (self.concentration * torch.log(self.scale) - 60 | (self.concentration + 1) * torch.log(value) - 61 | self.scale / value - torch.lgamma(self.concentration)) 62 | 63 | def entropy(self): 64 | return (self.concentration + torch.log(self.scale) + torch.lgamma(self.concentration) - 65 | (1 + self.concentration) * torch.digamma(self.concentration)) 66 | 67 | @property 68 | def _natural_params(self): 69 | return (-self.concentration - 1, -self.scale) 70 | 71 | def _log_normalizer(self, x, y): 72 | return torch.lgamma(- x - 1) + (- x - 1) * torch.log(-y.reciprocal()) 73 | -------------------------------------------------------------------------------- /dvi/kl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.distributions as tdist 4 | from torch.distributions.kl import register_kl 5 | 6 | from .invgamma import InverseGamma 7 | 8 | one_ovr_sqrt2pi = 1.0 / math.sqrt(2.0 * math.pi) 9 | one_ovr_sqrt2 = 1.0 / math.sqrt(2.) 
10 | 11 | 12 | def _standard_gaussian(x): 13 | return one_ovr_sqrt2pi * torch.exp(- x * x / 2.0) 14 | 15 | 16 | @register_kl(tdist.Normal, InverseGamma) 17 | def kl_normal_invgamma(p, q): 18 | # p: loc/scale, q: concentration, scale 19 | m = torch.numel(p.loc) 20 | S = p.scale.pow(2) + p.loc.pow(2) 21 | m_plus_2alpha_plus_2 = m + 2.0 * q.concentration + 2.0 22 | S_plus_2beta = S + 2.0 * q.scale / m 23 | 24 | term1 = torch.log(torch.sum(S_plus_2beta) / m_plus_2alpha_plus_2) 25 | term2 = S * (m_plus_2alpha_plus_2 / torch.sum(S_plus_2beta)) 26 | term3 = -(1 + torch.log(p.scale.pow(2))) 27 | return 0.5 * (term1 + term2 + term3) 28 | 29 | 30 | @register_kl(tdist.Normal, tdist.Laplace) 31 | def kl_normal_laplace(p, q): 32 | sigma = p.scale 33 | mu_ovr_sigma = p.loc / sigma 34 | tmp = 2 * _standard_gaussian(mu_ovr_sigma) + mu_ovr_sigma * torch.erf(mu_ovr_sigma * one_ovr_sqrt2) 35 | tmp *= sigma / q.scale 36 | tmp += 0.5 * torch.log(2 * q.scale * q.scale / (math.pi * p.scale)) - 0.5 37 | return tmp 38 | -------------------------------------------------------------------------------- /dvi/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | 7 | class GLLLoss(nn.Module): 8 | """Gaussian Log-Likelihood Loss 9 | """ 10 | def __init__(self, style='heteroschedastic', method='bayes', homo_logvar_scale=0.4): 11 | super(GLLLoss, self).__init__() 12 | 13 | self.style = style 14 | self.method = method 15 | self.homo_logvar_scale = homo_logvar_scale 16 | self.gaussian_loglikelihood = (self.heteroschedastic_gaussian_loglikelihood 17 | if self.style == 'heteroschedastic' 18 | else self.homoschedastic_gaussian_loglikelihood) 19 | 20 | def forward(self, pred, target): 21 | log_likelihood = self.gaussian_loglikelihood(pred, target) 22 | return log_likelihood 23 | 24 | def extra_repr(self): 25 | return 'style={}, method={}{}'.format(self.style, self.method, 26 | ', 
scale={}'.format(self.homo_logvar_scale) 27 | if self.style != 'heteroschedastic' else '') 28 | 29 | def heteroschedastic_gaussian_loglikelihood(self, pred, target): 30 | log_variance = pred.mean[:, 1].reshape(-1) 31 | mean = pred.mean[:, 0].reshape(-1) 32 | 33 | if self.method.lower().strip() == 'bayes': 34 | sll = pred.var[:, 1, 1].reshape(-1) 35 | smm = pred.var[:, 0, 0].reshape(-1) 36 | sml = pred.var[:, 0, 1].reshape(-1) 37 | else: 38 | sll = torch.tensor(0.0).to(target.device) 39 | smm = torch.tensor(0.0).to(target.device) 40 | sml = torch.tensor(0.0).to(target.device) 41 | return self.gaussian_loglikelihood_core(target, mean, log_variance, smm, sml, sll) 42 | 43 | def homoschedastic_gaussian_loglikelihood(self, pred, target): 44 | log_variance = torch.tensor(self.homo_logvar_scale).to(target.device) 45 | mean = pred.mean[:, 0].reshape(-1) 46 | sll = torch.tensor(0.0).to(target.device) 47 | sml = torch.tensor(0.0).to(target.device) 48 | if self.method.lower().strip() == 'bayes': 49 | smm = pred.var[:, 0, 0].reshape(-1) 50 | else: 51 | smm = torch.tensor(0.0).to(target.device) 52 | return self.gaussian_loglikelihood_core(target, mean, log_variance, smm, sml, sll) 53 | 54 | def gaussian_loglikelihood_core(self, target, mean, log_variance, smm, sml, sll): 55 | return (-0.5 * (torch.tensor(math.log(2.0 * math.pi)) + log_variance 56 | + torch.exp(-log_variance + 0.5 * sll) * (smm + (mean - sml - target) ** 2))) 57 | -------------------------------------------------------------------------------- /dvi/plot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | 4 | def toydata_result_plot(trainloader, model=None): 5 | mean_model = trainloader.dataset.base_model 6 | std_model = trainloader.dataset.noise_model 7 | 8 | xs, ys = iter(trainloader).next() 9 | 10 | train_x = torch.arange(torch.min(xs.reshape(-1)), torch.max(xs.reshape(-1)), 1/100).cpu() 11 | 12 | 
plt.plot(train_x.numpy(), mean_model(train_x).numpy(), 'red', label='data mean') 13 | plt.fill_between(train_x.cpu().numpy(), 14 | (mean_model(train_x) - std_model(train_x)).numpy(), 15 | (mean_model(train_x) + std_model(train_x)).numpy(), 16 | color='orange', alpha=1, label='data 1-std') 17 | plt.plot(xs.cpu().numpy(), ys.cpu().numpy(), 'r.', alpha=0.2, label='train sample') 18 | 19 | 20 | if model is not None: 21 | with torch.no_grad(): 22 | test_x = torch.arange(-1, 1, 1/100).reshape(-1, 1) 23 | pred = model(test_x) 24 | y_mean = pred.mean[:,0].cpu() 25 | ell_mean = pred.mean[:,1].cpu() 26 | y_var = pred.var[:,0,0].cpu() 27 | ell_var = pred.var[:,1,1].cpu() 28 | 29 | heteroskedastic_part = torch.exp(0.5 * ell_mean) 30 | full_std = torch.sqrt(y_var + torch.exp(ell_mean + 0.5 * ell_var)) 31 | 32 | plt.plot(test_x.cpu().numpy(), y_mean.numpy(), label='model mean') 33 | plt.fill_between(test_x.cpu().reshape(-1).numpy(), 34 | (y_mean - heteroskedastic_part).numpy(), 35 | (y_mean + heteroskedastic_part).numpy(), 36 | color='g', alpha = 0.2, label='$\ell$ contrib') 37 | plt.fill_between(test_x.cpu().reshape(-1).numpy(), 38 | (y_mean - full_std).numpy(), 39 | (y_mean + full_std).numpy(), 40 | color='b', alpha = 0.2, label='model 1-std') 41 | 42 | plt.xlabel('x') 43 | plt.ylabel('y') 44 | plt.ylim([-3,2]) 45 | plt.legend() -------------------------------------------------------------------------------- /dvi/variables.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributions as tdist 7 | 8 | from .invgamma import InverseGamma 9 | from .kl import kl_normal_invgamma, kl_normal_laplace 10 | 11 | class GaussianVar(object): 12 | def __init__(self, mean, var): 13 | self.mean = mean 14 | self.var = var 15 | self.shape = mean.shape 16 | 17 | 18 | class Parameter(nn.Module): 19 | def __init__(self, prior, approximation, q_loc, 
log_q_scale): 20 | super(Parameter, self).__init__() 21 | 22 | self.prior = prior 23 | self.approximation = approximation 24 | 25 | self.q_loc = q_loc 26 | self.log_q_scale = log_q_scale 27 | 28 | def q(self): 29 | return self.approximation(loc=self.q_loc, scale=torch.exp(self.log_q_scale)) 30 | 31 | def __repr__(self): 32 | args_string = 'Prior: {}\n Variational: {}'.format( 33 | self.prior, 34 | self.q() 35 | ) 36 | return self.__class__.__name__ + '(\n ' + args_string + '\n)' 37 | 38 | def surprise(self): 39 | q = self.q() 40 | p = self.prior 41 | return tdist.kl_divergence(q, p) 42 | 43 | def sample(self, n_sample=None, average=False): 44 | if n_sample is None: 45 | n_sample = 1 46 | samples = self.q().rsample((n_sample,)) 47 | return samples 48 | 49 | 50 | def get_variance_scale(initialization_type, shape): 51 | if initialization_type == "standard": 52 | prior_var = 1.0 53 | elif initialization_type == "wide": 54 | prior_var = 100.0 55 | elif initialization_type == "narrow": 56 | prior_var = 0.01 57 | elif initialization_type == "glorot": 58 | prior_var = (2.0 / (shape[-1] + shape[-2])) 59 | elif initialization_type == "xavier": 60 | prior_var = 1.0/shape[-1] 61 | elif initialization_type == "he": 62 | prior_var = 2.0/shape[-1] 63 | elif initialization_type == "wider_he": 64 | prior_var = 5.0/shape[-1] 65 | else: 66 | raise NotImplementedError('prior type "%s" not recognized' % initialization_type) 67 | return prior_var 68 | 69 | 70 | def gaussian_init(loc, scale, shape): 71 | return loc + scale * torch.randn(*shape) 72 | 73 | 74 | def laplace_init(loc, scale, shape): 75 | return torch.from_numpy(np.random.laplace(loc, scale/np.sqrt(2.0), size=shape).astype(np.float32)) 76 | 77 | 78 | def make_weight_matrix(shape, prior_type, variance): 79 | """ 80 | Args: 81 | shape (list, required): The shape of weight matrix. It should be `(out_features, in_features)`. 82 | prior_type (list, required): Prior Type. 
It should be `[prior, weight_scale, bias_scale]` 83 | `["gaussian", "wider_he", "wider_he"]`. 84 | """ 85 | variance = get_variance_scale(variance.strip().lower(), shape) 86 | stddev = torch.sqrt(torch.ones(shape) * variance).float() 87 | log_stddev = nn.Parameter(torch.log(stddev)) 88 | stddev = torch.exp(log_stddev) 89 | 90 | prior = prior_type.strip().lower() 91 | 92 | if prior == 'empirical': 93 | a = 4.4798 94 | alpha = a 95 | beta = (1 + a) * variance 96 | 97 | prior = InverseGamma(alpha, beta) 98 | 99 | mean = nn.Parameter(torch.Tensor(*shape)) 100 | nn.init.normal_(mean, 0.0, math.sqrt(variance)) 101 | return Parameter(prior, tdist.Normal, mean, log_stddev) 102 | 103 | elif prior == 'gaussian' or prior == 'normal': 104 | init_function = gaussian_init 105 | prior_generator = tdist.Normal 106 | elif prior == 'laplace': 107 | init_function = laplace_init 108 | prior_generator = tdist.Laplace 109 | else: 110 | raise NotImplementedError('prior type "{}" not recognized'.format(prior)) 111 | 112 | mean = nn.Parameter(init_function(0.0, math.sqrt(variance), shape)) 113 | 114 | prior_loc = torch.zeros(*shape) 115 | prior_scale = torch.ones(*shape) * math.sqrt(variance) 116 | prior = prior_generator(prior_loc, prior_scale) 117 | 118 | return Parameter(prior, tdist.Normal, mean, log_stddev) 119 | 120 | 121 | def make_bias_vector(shape, prior_type, variance): 122 | fudge_factor = 10.0 123 | variance = get_variance_scale(variance.strip().lower(), shape) 124 | stddev = torch.sqrt(torch.ones(shape[-2],) * variance / fudge_factor) 125 | log_stddev = nn.Parameter(torch.log(stddev)) 126 | stddev = torch.exp(log_stddev) 127 | 128 | prior = prior_type.strip().lower() 129 | 130 | if prior == 'empirical': 131 | a = 4.4798 132 | alpha = a 133 | beta = (1 + a) * variance 134 | 135 | prior = InverseGamma(alpha, beta) 136 | 137 | mean = nn.Parameter(torch.zeros(shape[-2],)) 138 | return Parameter(prior, tdist.Normal, mean, log_stddev) 139 | else: 140 | raise 
NotImplementedError('prior type "{}" not recognized'.format(prior)) 141 | -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "%matplotlib inline\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "from IPython.display import clear_output\n", 24 | "\n", 25 | "import math\n", 26 | "import numpy as np\n", 27 | "import torch\n", 28 | "import torch.optim as optim\n", 29 | "from torch.utils.data import DataLoader\n", 30 | "\n", 31 | "import dvi.bayes_utils as bu\n", 32 | "from dvi.dataset import ToyDataset\n", 33 | "from dvi.bayes_models import MLP, AdaptedMLP\n", 34 | "from dvi.loss import GLLLoss\n", 35 | "from dvi.plot import toydata_result_plot\n", 36 | "\n", 37 | "np.random.seed(1234)\n", 38 | "torch.manual_seed(1234)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "TRAIN_SIZE = 500\n", 48 | "X_DIM = 1\n", 49 | "Y_DIM = 2\n", 50 | "LEARNING_RATE = 1e-3\n", 51 | "EPOCHS = 20000\n", 52 | "ADAPTER = {\n", 53 | " 'in': {\"scale\": [[1.0]], \"shift\": [[0.0]]},\n", 54 | " 'out': {\"scale\": [[1.0, 0.83]], \"shift\": [[0.0, -3.5]]}\n", 55 | " }\n", 56 | "WARMUP = 14000\n", 57 | "ANNEAL = 1000\n", 58 | "\n", 59 | "if torch.cuda.is_available():\n", 60 | " DEVICE = torch.device('cuda')\n", 61 | " torch.set_default_tensor_type(torch.cuda.FloatTensor)\n", 62 | "else:\n", 63 | " DEVICE = torch.device('cpu')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Dataset\n", 71 | "\n", 72 | "First we generate a toy 
dataset according to:\n", 73 | "\\begin{equation} \n", 74 | "y = -(x+0.5)\\sin(3\\pi x) + \\eta\n", 75 | "\\end{equation}\n", 76 | "\n", 77 | "Where the noise $\\eta$ is drawn from a zero-mean Gaussian whose standard deviation depends on $x$:\n", 78 | "\n", 79 | "\\begin{equation}\n", 80 | " \\eta \\sim \\mathcal{N}\\left(0, \\left(0.45(x + 0.5)^2\\right)^2\\right)\n", 81 | "\\end{equation}" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "trainset = ToyDataset(data_size=TRAIN_SIZE, sampling=True)\n", 91 | "trainloader = DataLoader(trainset, batch_size=TRAIN_SIZE, shuffle=True)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "image/png": "\n", 102 | "text/plain": [ 103 | "
" 104 | ] 105 | }, 106 | "metadata": { 107 | "needs_background": "light" 108 | }, 109 | "output_type": "display_data" 110 | } 111 | ], 112 | "source": [ 113 | "toydata_result_plot(trainloader)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## Model" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "AdaptedMLP(\n", 132 | "MLP(\n", 133 | " (layers): Sequential(\n", 134 | " (0): VariationalLinearCertainActivations(\n", 135 | " (weight): Parameter(\n", 136 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 27.39900016784668)\n", 137 | " Variational: Normal(loc: torch.Size([128, 1]), scale: torch.Size([128, 1]))\n", 138 | " )\n", 139 | " (bias): Parameter(\n", 140 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 27.39900016784668)\n", 141 | " Variational: Normal(loc: torch.Size([128]), scale: torch.Size([128]))\n", 142 | " )\n", 143 | " )\n", 144 | " (1): VariationalLinearReLU(\n", 145 | " (weight): Parameter(\n", 146 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 147 | " Variational: Normal(loc: torch.Size([128, 128]), scale: torch.Size([128, 128]))\n", 148 | " )\n", 149 | " (bias): Parameter(\n", 150 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 151 | " Variational: Normal(loc: torch.Size([128]), scale: torch.Size([128]))\n", 152 | " )\n", 153 | " )\n", 154 | " (2): VariationalLinearReLU(\n", 155 | " (weight): Parameter(\n", 156 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 157 | " Variational: Normal(loc: torch.Size([2, 128]), scale: torch.Size([2, 128]))\n", 158 | " )\n", 159 | " (bias): Parameter(\n", 160 | " Prior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 161 | " Variational: Normal(loc: 
torch.Size([2]), scale: torch.Size([2]))\n", 162 | " )\n", 163 | " )\n", 164 | " )\n", 165 | "))" 166 | ] 167 | }, 168 | "execution_count": 5, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "mlp = MLP(X_DIM, Y_DIM)\n", 175 | "model = AdaptedMLP(mlp, ADAPTER, device=DEVICE)\n", 176 | "model" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "## Training " 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 6, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "criterion = GLLLoss()\n", 193 | "\n", 194 | "optimizer = optim.Adam(model.parameters(),\n", 195 | " lr=LEARNING_RATE)\n" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 7, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "def train(epoch, model, criterion, dataloader, optimizer):\n", 205 | " model.mlp.train()\n", 206 | "\n", 207 | " for xs, ys in dataloader:\n", 208 | " xs, ys = xs.to(DEVICE), ys.to(DEVICE)\n", 209 | "\n", 210 | " optimizer.zero_grad()\n", 211 | "\n", 212 | " pred = model(xs)\n", 213 | "\n", 214 | " kl = model.surprise()\n", 215 | "\n", 216 | " log_likelihood = criterion(pred, ys)\n", 217 | " batch_log_likelihood = torch.mean(log_likelihood)\n", 218 | "\n", 219 | " lmbd = bu.anneal(epoch, WARMUP, ANNEAL)\n", 220 | "\n", 221 | " loss = lmbd * kl / TRAIN_SIZE - batch_log_likelihood\n", 222 | "\n", 223 | " loss.backward()\n", 224 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", 225 | "\n", 226 | " optimizer.step()\n", 227 | "\n", 228 | " accuracy = torch.mean(torch.abs(pred.mean[:, 0].reshape(-1) - ys))\n", 229 | " if epoch % 20 == 0:\n", 230 | " print(\"Epoch {}: GLL={:.4f}, KL={:.4f}(anneal:{}) | MAE={:.4f}\".format(\n", 231 | " epoch, batch_log_likelihood.item(), kl.item()/500, lmbd, accuracy))" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 8, 237 | 
"metadata": {}, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "image/png": "\n", 242 | "text/plain": [ 243 | "
" 244 | ] 245 | }, 246 | "metadata": { 247 | "needs_background": "light" 248 | }, 249 | "output_type": "display_data" 250 | } 251 | ], 252 | "source": [ 253 | "toydata_result_plot(trainloader, model)\n", 254 | "plt.show()\n", 255 | "\n", 256 | "for epoch in range(1, EPOCHS+1):\n", 257 | " train(epoch, model, criterion, trainloader, optimizer)\n", 258 | " if epoch % 100 == 0:\n", 259 | " clear_output()\n", 260 | " toydata_result_plot(trainloader, model)\n", 261 | " plt.show()\n" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "## Comparison with MC\n", 269 | "\n", 270 | "We compare the output activations of the BNN as predicted using our deterministic approximation with the activations predicted by Monte Carlo approximation (see Figure 3 of the paper)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 9, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stderr", 280 | "output_type": "stream", 281 | "text": [ 282 | "/home/makotok/projects/deterministic-variational-inference-pytorch/dvi/bayes_models.py:66: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 283 | " 'scale': torch.tensor(adapter[ad]['scale']).to(self.device),\n", 284 | "/home/makotok/projects/deterministic-variational-inference-pytorch/dvi/bayes_models.py:67: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 285 | " 'shift': torch.tensor(adapter[ad]['shift']).to(self.device)\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "# from matplotlib import rc\n", 291 | "# rc('text', usetex=True)\n", 292 | "torch.set_default_tensor_type(torch.FloatTensor)\n", 293 | "model.device = torch.device('cpu')\n", 294 | 
"model.make_adapters(model.adapter)\n", 295 | "import numpy as np\n", 296 | "n_sample = 10000\n", 297 | "x = 0.25\n", 298 | "with torch.no_grad():\n", 299 | " model.mlp.cpu()\n", 300 | " test_x = torch.ones(n_sample, 1) * x\n", 301 | "\n", 302 | " samples = model.mcmc(test_x, n_sample)\n", 303 | " tmp = model(test_x)\n", 304 | " approx = [tmp.mean, tmp.var]\n", 305 | " \n", 306 | "def gaussian1d(mean, var):\n", 307 | " x_axis = torch.linspace(-5,5,1000)\n", 308 | " return x_axis.cpu().numpy(), (1.0 / torch.sqrt(2.0 * math.pi * var) * torch.exp(-1.0 / (2.0 * var) * (x_axis - mean)**2)).cpu().numpy()" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 23, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAM8AAACdCAYAAAATxyyjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEYNJREFUeJzt3XmQFGWax/Hv0013y9UgoIBi0wLK2dCIKIgcOqKoE6DjxY6zoWE4sQY6bsRqjAexw7qzG65OzAxjeLGuoYCBCl4oXgwKDY3QIC23oIhKg6igCDpAX/XsH281NNhHVXZmZRb1fCKKriMr86GqfpWZb735pqgqxpjkZYVdgDHpysJjjEcWHmM8svAY45GFxxiPLDzGeGThMcYjC48xHll4jPGoVSoW0qVLFy0sLEzFooxJ2po1a/aq6inJPi8l4SksLOTDDz9MxaKMSZqIfOnlebbZZoxHFh5jPLLwGOORhSfFCgp6IiKICAUFPcMux7RAShoMzFEVFTtYsGI3AL8c2T3kakxL2JrHGI8sPMZ4ZOExxiMLT4hycvOONB5YA0L6sQaDEFVXVR5pPABrQEg3tuYxxiMLjzEeWXiM8cjCY4xHFh5jPLLwBKx+XzYRCbsc4yNrqg5Y/b5sYM3RJxJb8xjjkYXHGI8sPCm0qyIb+G/emNeG6qqwqzEtZfs8KfLyc2145rF84B5m/CWbV+e0BfqHXZZpAVvzpMDGtTnMfKI9I8YcBk7nv/72PZWVAM9RXR1yccYzC0/g8vnzf3Sk62m1/Nsf9gPfUHxeFbf//gBwDnOfbRd2gcYjC0/g/oU932Rz17T9tGl79BSWF4yrBJ5j7sy2fLfH3oZ0ZO9agNwm2Z0MHlZJv0ENbZ9NIxaDt19tk+LKjB8sPAGaOxegB1f/+h+NTLGd4aMqefu11tb6loYsPAGaPh1gM8NGNJ6MidcdZP++bJYuOilldRl/WHgC8skn4IbnfoqsJl7lIcOr6NGzhoVvtLHDstOM/c4TkHnzjlwD7ml0OhEYM/4Qzz/dDtWOLFix9shj1g8u2mzNE5C5c2HUKIBdzU476qJKVAW4KuiyjI8sPAHYsgXWr4frr09s+oIza+hRUANcG2hdxl8WngC89JLbHLvmmsSmF4FRFx8GxrH/BzvmJ11YeALw5pswfDicfnriz7lg3GGgFauXW6tburDw+Oz772HVKpg
wIbnn9Tq7BthNeVluIHUZ/1l4fLZoEcRiyYfHHaG9kLWr8ojFgqjM+M3C47N33oGTT3abbcl7lwP7s/hsq/2CkA4sPD5ShVmzvmbfvrnk5HgZ8GMRAB+V5flfnPGdhcdHGzdCbW037rz/Mhas2H3MwB+J2UPvvtWUW3jSgoXHR4sXu7/F51Z6nsfQ8yr5eEMOhw5ak3XUWXh8tGQJwHZO7e59j3/wsCpqa4WtG3P8KssExMLjk1gMSkoAlrRoPv0GVZOVpWxca03WUWfh8cnGje43npaGp01bpXffGgtPGrDw+KRuf6el4QEYWFzF1k05gAUoypIOj4i0FZHsIIpJZ0uWQK9eABUtnteg4iqqqwTw9GORSZFmwyMiWSLyaxF5U0S+BbYAu0Vks4j8SUT6BF9mtMVisHQpjB3rz/wGDKk78nSMPzM0gUhkzbMY6A3cB3RT1TNU9VTgQmAl8JCI/CbAGiNvyxa3vzPGp896fgelZ69qYLQ/MzSBSKQfyCWqWg0gIjlADEBVvwdeBl6O35+xSkvd3wsv9G+eg4qr+HL7KGpqoJX11omkZtc89YLzFLBDRCpEpExEnhKR39WfJlOVlkLXrtC7t3/zHDi0GsgnJ2eYjWkQUcl8p40BeqhqrYicDgwBBgdTVnopLXVrHT/PXTUwvt9z652LueqfDgI2pkHUJNPaVgZ0BlDVXar6lqr+TzBlpY9du+Dzz/3dZAPofEoM2Ga/90RYMuGZAZSIyN0iMlpEOgRVVDoJYn/nqKVsXpdrx/dEVDLheQ6YhdvUmwJ8ICKfBVJVGikthbZtobg4iLkv5cD+LCq+sBaDKErmXdmpqg/Wv0NEMr7vfGkpjBgRVIvYUgA2rc2hZ6+aIBZgWiCZNc9aEfnX+neoqve+9yeA/fvdEFPBbLIBfE6nLrVsWmf7PVGUzPdlV+ASEbkHKAfWAWtVdV7TTztxrVzpehcEFx7X6rbZwhNJCa95VPV6Ve0PnAn8AfgEOD+owtJBaSlkZ8P5Ab4KA4ZUs+ebbL792vrwRk2zax4REVU9clam+KZaefzS4DSZorTUNRS0bx/cMup+77FNt+hJqG+biPxORArq3ykiuSJysYjMBG4KprzoqqqCsrJgN9kAevauoU3bmG26RVAi+zwTgFuA50WkF7APaI0L3kJguqp+FFyJ0VReDocOBR+e7GzoX1TNpnUZ3X0wkpoNj6oeBh4HHo93AO0CHFLVH4IuLsqWulZkRqeg4/PAIVXMmtEe6BT8wkzCEt4LFZFPcSebuQ34hYj42A0y/ZSUQN++MHx4z2NOSBWEo8f3XBDI/I03yTThzAd2AF8D44ENIrJDRFaIyIxAqouo2lrXWDB2LFRU7DgyRlvy47Ql5uwB1bTKUez4nmhJJjyXquqdqvqEqt4GXATMAa4DXg+kuohatw4OHPDvyNHm5ObBWf2qcccfmqhIJjzficiQuhuqWgZcrqo7VfVN/0uLLjfElH9HjiZiYHEVcC6HDqVumaZpyfQw+C0wS0Q2AWuB/kBGvpUlJW6wjx49UrfMAYOrgXasWpW6NZ5pWjI9DLbhthveBroB24ArA6orsmIxWLYs9R/gAYNdo8GyZaldrmlcUn2BVTUGvBK/ZKRNm9xgH6kOT7t8BTawdGlRahdsGmUdppJUt78TzqbTEpYvd70bTPgsPEkqKYEzzoCeoYzFsZiDB2H16jCWbY5n4UmC6tHBDQP6PbQZJYjUH9rXhMnCk4StW+Hbb8Ns7fqewYMtPFFh4UlCuPs7kJObx7p1f+X99w8hkmfjuIXMwpOExYuhe3foE9Lo3NVVlfz7wzcDrXnwsa+oqNgRTiEGsPAkLBZzp4kfPz6s/R1n0NAqsrKVtaszfuyV0Fl4ElReDt99B5deGm4dbdspfQdUU15mB8eFzcKToIUL3d/x48OtA+Cc8yvZtiWH+ACuJiQWngQtXAhDh8Kpp4ZdCZwzogpVwR0ZYsJi4UnAjz/CBx+Ev8l
Wp0+/atrnx4DLwi4lo1l4ErB4MVRXRyc82dlQPLwSuJTMG7MoOiw8CZg/Hzp0cOMVFBQEf9h1IoaNrAJOo7y82UlNQGwE8WbU1sIbb8AVV0BOztHDruuEdc6c4RccBtoxf342w4aFUkLGszVPM8rKYM8emDgx7EqO1eFkBUp57bWwK8lcFp5mzJ/v1jiXXx52JQ15jQ0bYPv2sOvITBaeJqi68Iwb5/Z5ome++3d+yGVkKAtPE9avdz2pr7467Eoa8zlFRfDyy2HXkZksPE2YM8edtOq668KupHGTJ8Py5fDFF2FXknksPI2IxeD55+Gyy6BLl7CraVhObh5Tp7rDEs48c6odopBiFp5GlJZCRQXceGPYlTSuuqqSBSvKGFRcRY+eD9ghCilm4WnE7NnuRL1Ra6JuyLgJh9j5ZSvAfvBJJQtPA374we3v3HCDC1DUXXjxYXJzFbg17FIyioWnAbNmwcGDcPvtYVeSmHbtlXGXHQL+mX37wq4mc1h4jqMKjz/uTg9/zjnR6cvWnF9eexBoyzPPhF1J5rDwHGfhQvfbzpQp7nYqTiHih15n1wDLeOwx1x/PBM/CU48qPPCAG8D9+uvDrsaL6WzfDi++GHYdmcHCU8/ChbBiBUydCnlpOb7GqwweDNOmQU1N2LWc+Cw8caruQ1dQALfcEnY1Xil//CNs2wYzZ4Zdy4nPwhM3c6Y7/GDaNMhN04FpcnLzmDRJgJXceutX9OhhZ1QIkoUH2LsX7r4bRo2Cm28OuxrvXI+D3Ux/pg9ZWd3ZtStN2trTVMaHRxXuuAP274cZMyDrBHhF+vSrYdINB4HbbFzrAJ0AH5WWeeQR1zr1wAMwcGDY1fjnxt/+BGxl8mTXR8/4L6PD8847cNddcNVVcO+9YVfjr5NaKzCJw4fd/+/AgbArOvFkbHgWLIBJk6CoCJ599ujmWrr0KEjMVubMcQf1XXKJGy7Y+CfjwhOLwYMPum/joiJ4771jD7FOlx4FibrySnjlFRegkSOxoap8lFHhKSuDMWPg/vvhmmvg/fehU6ewqwpOTm4eIsLEiUJl5Wi2b9/NiBFw331YB1IfnPDh2bvXHZtz0UWus+enn8LTT8MLL0B+ftjVBauu6dpd5lFbO4jJk+Ghh6Cw0PXfW77ceiN4lbJBD2Mx97dueNhk/jb12MGD8NNP7rJnD+zc6VqXtm93mygbNrjpevVym2u33w7t2x+tq6CgZ8YcgZmT+w9mzxagiAMHfs+TT17LE0+cRH6+G8S+Xz93KSyEzp3dWrljR/ejcU7OsZe03x30QUrCs2aNG185lbp1g+Ji+NWvYMIEGD7cNQo0FJYojACaCnVrojpXjSmgpnosBw78gpKSQSxbNpBY7OSE5pWd7S4NhSjR+5KZNophFU3BSOEi8iOwNfAFNa8LsNdqAKJRRxRqAOirqu2bn+xYqdps26qq56ZoWY0SkQ/DriMKNUSljijUUFeHl+ed8A0GxgTFwmOMR6kKz/+maDnNiUIdUagBolFHFGoAj3WkpMHAmBORbbYZ41Eg4RGR60Rkk4jERKTR1hQRmSAiW0Vkm4j43q9ZRDqJyN9F5NP43wZ/xBCRh+P1fiwij4iPPUKTqKFARBbGa9gsIoV+1ZBMHfFp80Vkp4g8muoaRKRYRFbE34/1InKDT8tu8rMmInki8mL88bKEXn9V9f0C9Af6AkuAcxuZJhv4DOgF5ALrgAE+1/EwcG/8+r3AQw1McwGwPF5PNrACGJfKGuKPLQHGx6+3A9qk+rWoN+3fgDnAoyG8H2cDZ8WvnwbsBjq2cLnNftaAKcCT8euTgRebna+fL04jH4jGwjMSeLfe7fuA+3xe/lage/x6d9zvTQ3VsQZoDbQBPgT6p7iGAUBpwO9Fs3XEHxsGvADcHEB4EqrhuOesqwtTC5bb7GcNeBcYGb/eCvfjrTQ13zD3eU4H6h/juDN+n5+6qmpdf5Svga7HT6CqK4DFuG+43bgX+eNU1oD
7tv1BRF4RkY9E5E8i4neHpmbrEJEs4M/A3T4vO+EajqvnPNya4rMWLjeRz9qRaVS1BtgPdG5qpp57GIjIIqBbAw9NVdWUneivqTrq31BVFZGfNS2KSB/cZmaP+F1/F5HRqrosVTXg3ofRwFBgB/Ai7pv/6URr8KmOKcBbqrrT626fDzXUzac7MBu4SVVjnooJmOfwqOolLVz2LuCMerd7xO/zrQ4R+UZEuqvq7vib8W0Dk10NrFTVn+LPeRu3mk84PD7UsBNYq6rb4895DRhBkuHxoY6RwGgRmYLb78oVkZ9UNeHGHB9qQETygTdxX8QrE112ExL5rNVNs1NEWgEdgCaPvQ1zs201cJaInCkiubidtNd9XsbrwE3x6zdRdwbcY+0AxopIKxHJAcYCfm62JVLDaqCjiJwSv30xsNnHGhKqQ1VvVNUCVS3EbbrNSiY4ftQQ/yy8Gl/2Sz4tN5HPWv3argXe1/gOUKMC2jm9GvdtWgl8Q3xnDdd68la96a4APsFt004NoI7OwHvAp8AioFP8/nOB/4tfzwZm4AKzGfhLqmuI3x4PrAc2AM8CuWHUUW/6m/G/wSCR9+M3QDWwtt6l2Idl/+yzBvwnMDF+/SRgHrANWAX0am6e1sPAGI+sh4ExHll4jPHIwmOMRxYeYzyy8BjjkYXHGI8sPMZ4ZOFJAyIyT0QeFZFSEflSRC4Ukdki8omIJNWFx/gnZSOGmhYpAlao6h0icj+uz9s4YA+uL1aeqlaGWWAmsvBEnIicBHQEpsfvUuBpjXftF5FaoCqk8jKabbZF30CgXI92yx8ClAGISA/gK7U+VqGw8ERfEe5oyjqDcR1IwQVp/c+eYVLCwhN9RbiexXWbcK1Vte7sOvWDZFLMelUb45GteYzxyMJjjEcWHmM8svAY45GFxxiPLDzGeGThMcYjC48xHv0/M/jAaB+51sUAAAAASUVORK5CYII=\n", 319 | "text/plain": [ 320 | "
" 321 | ] 322 | }, 323 | "metadata": { 324 | "needs_background": "light" 325 | }, 326 | "output_type": "display_data" 327 | }, 328 | { 329 | "data": { 330 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMsAAACfCAYAAABX5C3SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFJBJREFUeJzt3Xt4FPW9x/H3NzcSAlaBcNSEgBc8qCExCRCKCigWEBAVfVSsRYrWp7Qea22pVc7R1p7aih6KvcGxHLF9sIn4qEhDFfAx4SKBmlAKKVJETUkUi3ITEjabZH/nj9mEALlMkp2ZvXxfz7MPu5vZmU8SPpn7jBhjUEp1Ls7rAEpFCi2LUjZpWZSyScuilE1aFqVs0rIoZZOWRSmbtCxK2aRlUcqmBDcmMmDAADNkyBA3JqVUl1VUVHxujEnrbDhXyjJkyBDKy8vdmJRSXSYi/7QznC6GKWWTlkUpm7QsStnkyjqLaltm5mCqq/e1vB40KJN9+2wtPodcQ0MDNTU1+Hw+T6bvhuTkZDIyMkhMTOzW57UsHqqu3kdx2f6W19O+fJ5nWWpqaujbty9DhgxBRDzL4RRjDAcPHqSmpoYLLrigW+PQxTAFgM/no3///lFZFAARoX///j2ac2pZVItoLUqznn5/WhalbNKyhJHEpF6ICCJCZuZgr+Oo0+gKfhhp8Ne3rPB7ubIfaRobG0lIcP6/ss5ZVNhYuHAhWVlZZGVlsWjRIqqqqsjKymr5+jPPPMOPfvQjAMaPH8+DDz7IiBEjePbZZ3n55ZfJysoiJyeHsWPHOpJP5yyeuoGfPXo2xsDY6054HabFgw/C9u2hHecVV8CiRe1/vaKigmXLlrF161aMMRQUFDBu3LgOx+n3+1uOORw+fDhr1qwhPT2dI0eOhDJ6Cy2LB4yB734XYBW7K5uIi4PNpcnA7wgEIC4G5/ebNm3i5ptvJjU1FYAZM2awcePGDj9z++23tzy/8sormT17NrfddhszZsxwJKOWxQMLF8KzzwI8y/Ov3g4Cf1zah5deuJeiZce4855aT/N1NAdw05EjRwgEAi2vT99H0lwsgCVLlrB161ZWr15Nfn4+FRUV9O/fP6R5YvBvmLfeew8efRSsP34PEp8A8fFw133HgT9Q+H992Ls79v6GXX311axcuZK6ujpqa2t57bXXuP766zlw4AAHDx6kvr6e4uLidj//wQcfUFBQwBNPPEFaWhrV1dUhzxh7vxWPPfQQpKbC4sXw6qsn37f2lz3Al87+Kkv+5yyv4nkmLy+P2bNnM2rUKADuvfdeRo4cyWOPPcaoUaNIT09n2LBh7X5+3rx5vP/++xhjmDBhAjk5OaEPaYxx/JGfn2+UMRs2GAPGLFhgvQZMcdn+lgdg7n/4iLHWaia6mm3Xrl2uTs8rbX2fQLmx8f9YF8NctGABpKXBt7/d/jATpp4g7d+aEPlP3UEZZrQsLtmzB4qLYe5c6N27/eESE2HqrXUYczWLCw9QXLb/lMP4lXe0LC5ZtsxakV+6dGTLHKM9E6acABpYs6qDVinXaVlcEAjA8uUwaRJ88kk5xWX7TzmP5XTn9AsAr/P2n1No8LuXU3VMy+KC0lKoqYFZs7ryqd/xxdE4yjYkO5RKdZWWxWGZmYOZMOEF4Ch33JHShU+uY8DAJtav1bKECy2Lw6qrPyM5ZRYTpydSXPZRFz5pGDPex7atvUhI7N+ynuPW1rHMzMGnTLOnDzuZRYS77rqr5XVjYyNpaWlMmzat5b033niDESNGcNlll5Gbm8v3vvc9R77/tuhOScddh+9EHFdP6
PrprGPG+1i1IhW4luKyX7a878bh+6dfH6Cn7GROTU2lsrKSEydOkJKSwrp160hPT2/5emVlJffffz+rV69m2LBhNDU18dxzz4UsY2d0zuK4G+mdGiArt+tr6pdmN3D2OU3ALaGPFaamTJnC6tWrASgsLGTmzJktX1uwYAHz589v2ZMfHx/P3LlzXcumZXFQUxPADYwYU093rr4THw+jx9YDU6mP3isUneKOO+6gqKgIn8/Hjh07KCgoaPlaZWUl+fn5nmXTsjho61aAgRRcXd/tcYy5xgf0Yfu7vUIVK6xlZ2dTVVVFYWEhU6ZM8TrOKbQsDlq1CqCB/NHdL8vwXD9wnIotsVEWgOnTp/P973//lEUwgMsvv5yKigqPUmlZHPWnPwGsp09f0+1xJCYBvE1FWRKm+6OJKHPmzOHxxx9n+PDhp7w/b948nnzySfbs2QNAIBBgyZIlruXSrWEOqa6GXbsA3gCyOhm6M2/yr/3T+aQ6nvTMpp6Hs2HQoMyQbnUbNCjT9rAZGRk88MADZ7yfnZ3NokWLmDlzJnV1dYjIKZuVnaZlcci6dc3P1gLzeji2NwGo2NKL9My6Ho7LHi+uuXz8+PEz3hs/fjzjx49veT1t2jRXC9KaLoY5ZO1aOO88gMoQjO0j0jMbqSiLnfWWcKRlcUBTkzVnmTgxdOPMG13Pzm1JMbMJORxpWRzw17/CoUMhLssoP36/sLsyKXQjPY2J8i0IPf3+tCwOWLvW+ve660I3zsuv8BMXb9hR4UxZkpOTOXjwYNQWxgRvOZGc3P0DU3UF3wFvvw05OTBwYOjG2TvVMPTSBsfKkpGRQU1NDZ999pkj4w8HzTcz6i4tS4jV18Pbb/swZjEiD4V03Nl5fl59MRXoE9LxAiQmJnb7Jj+xQhfDQuwvfwFjkvnPp2aH9KhdgJx8P01NAlwV0vEqe7QsIVZSAhAg64rQnw88LNtPQqIBrg35uFXntCwhVloKsJ0+Z4V+RTk5GYZlNaBl8YaWJYR8PigrAyh1bBo5+fVALocPOzYJ1Q4tSwht3WoVBkocm0b2CD8QR79+N+kF+FymZQmh0tLm20V0fKuEnrjksgaglmm3vqgX4HOZliWESkogNxfgqGPTsM643MTObc7tyVdt63JZRCRVROKdCBPJfD7YsgVaHSDroFL++WEiRw9H9624w02nZRGROBG5U0RWi8gBYDewX0R2icjTInKx8zHDX1mZtUPymmvcmJq1TrTzrzp3cZOdOUsJcBHwCHCuMWaQMWYg1p6xLcBTInJXRyOIBc3rK1e5sr+wguSUADsdOvRFtc3O4S7XGWMaTn/TGHMIeAV4RUS6ce2S6FJaCnl58KUvuTG1Ri7PadA5i8vszFlyReQhEbm29bqKiJwtIn0A2ipTLDlxwlpfcWcRzDI8z8++jxKBEB6tqTpkpyxfA4YBucAfROS/ROQsQIClToaLFJs3g98PTz89pcNbSYTS8Lzmw2k6vv21Ch07ZXkYa6V+EnAuMAF4D/ixzc9HPesQl0ZWvLUs5AdPtufif28gpXcAcHF2FuM6XWcxxtQBC4GFItILGAqkAecDec7GiwzWwZPl9E4d4to04xOsE8LKN493bZqxzs6m45blCmNMvTGm0hhTYox5EfjB6cPEmtpa67B8J48Ha491Ab5L2e/OzCzm2dp0LCL/ISKnXPhJRJKAcSLye+BuR9JFgHfegYYGcPJ4sPZk51vrLevXuz7pmGSnLJOBJqBQRJp3Rn4EvA/MBBYZY15wMGNYKy2FhASAd1yf9oVDG4GjwcVA5TQ76yw+4LfAb4P7UwYAJ4wxR5wOFwlKSmDkSCgrq3V92vEJABsoLb3B9WnHIttbs0TkfeBl4JvABBG5yLFUEeLYMXj3XXf3r5yplD174JNPvMwQG7qy6fd1YB/wKfAVYKeI7BORMhH5X0fShbl33rEuqOdlWRISrMW/9PQ79fwWh3Xl6i4TjTHZzS9EZBlwM/BrICfUw
SJBSYl1yPyYMd5laGx8l9S+Aa68ZikPPLLQlVvoxaquzFkOikhLKYwxW4HrjTE1xpjVoY8W/kpKoKAAevf2MoV1cQw9v8V5XZmzfAPrcJe/A9uBS4ETjqSKAEePQkUFzJ/vdRLremJbNybz+QE9oMJJtn+6xpi9WIflv4F12MteYKpDucLepk0QCHi9cm9pPk5M5y7O6tIVKY0xAeDV4COmlZRAUhKMHu11EhhycSN9+gYcu7Srsuh8u5vWrYMrr4SUFK+TWCedZeX62aFzFkdpWbph/37YsQMmTfI6yUk5I/z865ME4EKvo0QtLUs3NN9S4oc/vAIRce0clo7ktdwROYwaHGX0KvrdsGYNwKeseufN4HXC8Hz/xvkZTZyb3sinH0/2NEc00zlLFwUCzTdXXdtSlHAgAvkF9cC11Nd3OrjqhjD6dUeGbdvg888B1ngd5Qx5o/1AH95x/wDomKBl6aI1LR1Z19FgnrDOb/Hz5pteJ4lOWpYuWrPGuuQRhN/t5FJ6G2CjlsUhWpYu+OIL68qT4bTJ+HRx8W+xcyeIpOsRyCGmZemCt96CxsbwLkugqRiABx7dpVfYDzEtSxe8/jr062ftuQ9flfRPa6K8rJfXQaKOlsWmxkYoLoapU5vPuQ9fo66qZ9uWJKD793xXZ9Ky2LRpExw6BDfd5HWSzn15nA/fiTjgOq+jRBUti00rV0KvXjBxotdJOjc8z09qnwDWiawqVLQsNhhjra+IvEXfvuFxLFhHEhNhxJh6YDqNjV6niR5aFhsqKqCqCny+QorL9rt2PeOe+PI4HzCATZu8ThI9tCw2FBY238sxcs55yx/tB+p46SWvk0QPLUsnmpqgqAiuvx4gcq4raO3Nf50VK6zbYaie07J0YuNG6wJ2M2d6naQ7lnPoUOvj2VRPaFk6UVhoXerohgi8QmpC4nrgM6ZPf0kPfQkBLUsH6upgxQpr30pqqtdpuq6xoZapM3qTlHQb1dWHvY4T8bQsHXj5ZThyBL7xDa+TdN+1U3z4/QLc6XWUiKdl6cBzz8Ell8C4CL5t4yWXNXDh0AZgLsZ4nSayaVnaUVlp3Vj1vvusU3YjlQhMmVEH5LB5s9dpIpuWpR2LF1sX0bs7Cu5pNm6iDzjK4sVeJ4lsWpY2HDgAzz8PiYkrSEuTsLncUXdZ+1x+z4oV8PHHXqeJXFqWNvzqV1BfD7W1j7cc3hIJh7h07BcEAvDMM17niFxaltMcOwa/+U3zofi7vY4TMolJ+2lqeoFFi+oQGaj7XbpBy3KaX/4SDh+Ghx/2OkloNfjrWVI0FZEUbv3ah3rKcTdoWVr59FP4+c+tuUpBgddpQi9jcBNjv+Jj1UupQGanw6tTaVlaefxx8Pngqae8TuKcu+ceC24Kj+Jv0iFalqDycli6FL71LWtHZLQaeG6AGXfVAnewYYPXaSKLlgVrbjJrFpx3Hvz4x16ncd4tdx0HPuTrX7c2aCh7tCxY94V87z1r30p29uCI36/SmeRkgFlUVcF3vuNxmAgS82X54x9h4UJr8WviRKiu3hcl+1U6lphUTiDw3yxbBv36zfM6TkSI6bJs3gxz5sDYsfCLX3idxl0N/npWbryX3IJ6Dh/+mZ4gZkPMlmXDBpg8GZqaPmLDhgH06hXdi15tSUiAR356BKjklltO3tFMtS0my1JUZF2vOCMDGhuvorisMiYWvdrSO9WQkHgjtbXbmTSpAZE5DBqke/fbElNlOXQI7rnHOp8+Lw/Wrwf4xOtYnmts2MdL684jZ0QAeJ6amgXBGzap1mKiLEeOWAcQXnwxvPACPPIIlJZCWprXycJHah/DE4sOM+ubx4AZXHQR/OQn1h8YZYnasnz+ObzyirX/JD0d5s2DkSNh+3ZYvnwwSUmxt47Smfh4uO3uWiCHa6+Fxx6D88+35sRFRcT83MaV68EbAw0NJ583n97a+t+uvFdXB8ePW49jx6xfYnW19aiqsu77WFVlfaZfP7jzTmvTc
G6u9V7z5mHw/i7D4Sgx6UNWrhRgOI2N36Ko6FaKigYAkJlp/Rwvusj6I5SeDueeC336nHykplobD+LjTz7i4iL7jFNwqSzbtllnHTotJQUGDbLmIIcPP8nRo29w6FAZS5c2sXx5Cj7fCedDRIEGf/0pf0xe33QTe3Yd5OG5T7NvXzb79uUCg4HeXRpvXNzJ8oSqOG4WUIwLVzEQkc+Afzo+IcsAIJIWGDSvs+zkHWyM6XQN1pWyuElEyo0xI7zOYZfmdVYo80btCr5SoaZlUcqmaCzLc14H6CLN66yQ5Y26dRalnBKNcxalHBHxZRGRn4jIDhHZLiJrReT8doZrCg6zXURWuZ2zVQ67ee8WkfeDD8+uiykiT4vI7mDm10Tk7HaGqxKRncHvq9ztnK1y2M07WUT+ISJ7ReSHtkZujInoB3BWq+cPAEvaGe6411nt5gX6AR8G/z0n+Pwcj/JOBBKCz58CnmpnuCpgQBj8fDvNC8QDHwAXAknA34DLOht3xM9ZjDFftHqZCoT1SpjNvJOAdcaYQ8aYw8A6YLIb+U5njFlrjGm+5/EWIMOLHHbZzDsK2GuM+dAY4weKgBs7G3fElwVARH4qItXAV4HH2hksWUTKRWSLiNzkYrwz2MibDlS3el0TfM9rc4A32vmaAdaKSIWI3Odipo60l7dbP9+IKIuIvCUilW08bgQwxsw3xgwCXgTub2c0g421J/dOYJGIXBTmeV3TWd7gMPOBRqzMbbnKGJMHXA98W0TGhnneLnPlQMqeMsZcZ3PQF4E/A4+3MY6Pg/9+KCKlQC7WcmvIhSDvx8D4Vq8zgNIeB2tHZ3lFZDYwDZhgggv9bYyj+ed7QERew1rUceTKZCHI+zEwqNXrjOB7HYqIOUtHRGRoq5c30sbVvEXkHBHpFXw+ALgS2OVOwjOydJoXWANMDOY+B2ul1ZNLSojIZOAHwHRjTF07w6SKSN/m51h5K91LeUqWTvMC7wJDReQCEUkC7gA630Lq9daLEGz9eAXrF7MD+BOQHnx/BLA0+HwMsBNrq8dO4J5wzht8PQfYG3x83cO8e7GW77cHH0uC758P/Dn4/MLgz/ZvwN+B+eGcN/h6CrAHa+nCVl7dg6+UTRG/GKaUW7QsStmkZVHKJi2LUjZpWZSyScuilE1aFqVs0rJEmODe8l+LyGivs8QaLUvk+SaQDFzldZBYo2WJPJOBf2AdyqFcpGWJICKSjHWWXx6w3uM4MUfLElmGYpVltzGmweswsSYizmdRLdKAS7BxCqwKPZ2zRJbzsQ7xjwue56JcpGWJECKSgLWuci6wBGjyNlHs0fNZlLJJ5yxK2aRlUcomLYtSNmlZlLJJy6KUTVoWpWzSsihlk5ZFKZv+H2J9F51toIwcAAAAAElFTkSuQmCC\n", 331 | "text/plain": [ 332 | "
" 333 | ] 334 | }, 335 | "metadata": { 336 | "needs_background": "light" 337 | }, 338 | "output_type": "display_data" 339 | } 340 | ], 341 | "source": [ 342 | "plt.figure(figsize=(3,2))\n", 343 | "# plt.hist(samples[:,0,0].cpu().numpy(), int(np.round(200/(1))), density=True, edgecolor='k', facecolor='#b4c7e7')\n", 344 | "plt.hist(samples[:,0,0].cpu().numpy(), 30, density=True, edgecolor='k', facecolor='#b4c7e7')\n", 345 | "\n", 346 | "plt.plot(*gaussian1d(approx[0][0,0], approx[1][0,0,0]), 'b')\n", 347 | "plt.xlim([-1,-0.0])\n", 348 | "plt.yticks([])\n", 349 | "plt.xlabel('$m$')\n", 350 | "plt.ylabel('$q(m)$')\n", 351 | "plt.show()\n", 352 | "\n", 353 | "plt.figure(figsize=(3,2))\n", 354 | "# plt.hist(samples[:,0,1].cpu().numpy(), int(np.round(200/(3.8-1.9))), density=True, edgecolor='k', facecolor='#b4c7e7', label=\"MC\")\n", 355 | "plt.hist(samples[:,0,1].cpu().numpy(), 30, density=True, edgecolor='k', facecolor='#b4c7e7', label=\"MC\")\n", 356 | "plt.plot(*gaussian1d(approx[0][0,1], approx[1][0,1,1]), 'b', label=\"ours\")\n", 357 | "plt.xlim([-3.8,-1.9])\n", 358 | "plt.yticks([])\n", 359 | "plt.xlabel('$\\ell$')\n", 360 | "plt.ylabel('$q(\\ell)$')\n", 361 | "plt.legend()\n", 362 | "plt.show()" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [] 371 | } 372 | ], 373 | "metadata": { 374 | "kernelspec": { 375 | "display_name": "Python 3", 376 | "language": "python", 377 | "name": "python3" 378 | }, 379 | "language_info": { 380 | "codemirror_mode": { 381 | "name": "ipython", 382 | "version": 3 383 | }, 384 | "file_extension": ".py", 385 | "mimetype": "text/x-python", 386 | "name": "python", 387 | "nbconvert_exporter": "python", 388 | "pygments_lexer": "ipython3", 389 | "version": "3.7.1" 390 | } 391 | }, 392 | "nbformat": 4, 393 | "nbformat_minor": 2 394 | } 395 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | import dvi.bayes_utils as bu 11 | from dvi.dataset import ToyDataset 12 | from dvi.bayes_models import MLP, AdaptedMLP 13 | from dvi.loss import GLLLoss 14 | 15 | 16 | def train(epoch, model, criterion, dataloader, optimizer): 17 | model.mlp.train() 18 | 19 | for xs, ys in dataloader: 20 | xs, ys = xs.to(args.device), ys.to(args.device) 21 | 22 | optimizer.zero_grad() 23 | 24 | pred = model(xs) 25 | 26 | kl = model.surprise() 27 | 28 | log_likelihood = criterion(pred, ys) 29 | batch_log_likelihood = torch.mean(log_likelihood) 30 | 31 | lmbd = bu.anneal(epoch) 32 | 33 | loss = lmbd * kl / args.train_size - batch_log_likelihood 34 | 35 | loss.backward() 36 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) 37 | 38 | optimizer.step() 39 | 40 | accuracy = torch.mean(torch.abs(pred.mean[:, 0].reshape(-1) - ys)) 41 | if epoch % 20 == 0: 42 | print("Epoch {}: GLL={:.4f}, KL={:.4f}(anneal:{}) | MAE={:.4f}".format(epoch, batch_log_likelihood.item(), kl.item()/500, lmbd, accuracy)) 43 | 44 | 45 | def main(): 46 | trainset = ToyDataset(data_size=args.train_size, sampling=True) 47 | trainloader = DataLoader(trainset, batch_size=args.train_size, shuffle=True) 48 | 49 | 50 | mlp = MLP(args.x_dim, args.y_dim, args.hidden_dims) 51 | model = AdaptedMLP(mlp, args.adapter, device=args.device) 52 | 53 | criterion = GLLLoss() 54 | 55 | optimizer = optim.Adam(model.parameters(), 56 | lr=args.lr) 57 | 58 | for epoch in range(1, args.epochs+1): 59 | train(epoch, model, criterion, trainloader, optimizer) 60 | torch.save(model, 'temp.pth.tar') 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description="Deterministic Variational Inference") 65 | parser.add_argument('--method', type=str, default='bayes', 66 | 
help="Method: bayes|point") 67 | parser.add_argument('--x-dim', type=int, default=1, 68 | help="input dimension") 69 | parser.add_argument('--y-dim', type=int, default=1, 70 | help="output dimension") 71 | parser.add_argument('--nonlinear', type=str, default='relu', 72 | help="Non-Linearity") 73 | 74 | parser.add_argument('--epochs', type=int, default=3000, 75 | help="Epochs") 76 | parser.add_argument('--lr', type=float, default=1e-3, 77 | help='learning rate') 78 | parser.add_argument('--train-size', type=int, default=500, 79 | help='Train size (Also Training batch data size)') 80 | parser.add_argument('--test-size', type=int, default=100, 81 | help='Test size (Also Testing batch data size)') 82 | 83 | parser.add_argument('--seed', type=int, default=3, 84 | help="Random Seed") 85 | 86 | args = parser.parse_args() 87 | args.prior_type = ["empirical", "wider_he", "wider_he"] 88 | args.hidden_dims = [128, 128] 89 | args.adapter = { 90 | 'in': {"scale": [[1.0]], "shift": [[0.0]]}, 91 | 'out': {"scale": [[1.0, 0.83]], "shift": [[0.0, -3.5]]} 92 | } 93 | 94 | if torch.cuda.is_available(): 95 | args.device = torch.device('cuda') 96 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 97 | else: 98 | args.device = torch.device('cpu') 99 | # args.device = torch.device('cpu') 100 | 101 | np.random.seed(args.seed) 102 | torch.manual_seed(args.seed) 103 | print(args) 104 | 105 | main() 106 | 107 | -------------------------------------------------------------------------------- /test/InvGamma_validation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 16, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import seaborn as sns" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 17, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 
21 | "from distributions import InverseGamma\n", 22 | "import numpy as np" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 25, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "alpha = torch.tensor(3.)\n", 32 | "beta = torch.tensor(0.5)\n", 33 | "p = InverseGamma(alpha, scale=beta)\n", 34 | "alpha = torch.tensor(3.)\n", 35 | "beta = torch.tensor(1.)\n", 36 | "p2 = InverseGamma(alpha, scale=beta)\n", 37 | "alpha = torch.tensor(2.)\n", 38 | "beta = torch.tensor(1.)\n", 39 | "p3 = InverseGamma(alpha, scale=beta)\n", 40 | "alpha = torch.tensor(4.47)\n", 41 | "beta = torch.tensor(5.47)\n", 42 | "p4 = InverseGamma(alpha, scale=beta)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 26, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "result1 = []\n", 52 | "result2 = []\n", 53 | "result3 = []\n", 54 | "result4 = []\n", 55 | "for i in range(10000):\n", 56 | " result1.append(p.rsample().item())\n", 57 | " result2.append(p2.rsample().item())\n", 58 | " result3.append(p3.rsample().item())\n", 59 | " result4.append(p4.rsample().item())" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 27, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "" 71 | ] 72 | }, 73 | "execution_count": 27, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | }, 77 | { 78 | "data": { 79 | "image/png": "\n", 80 | "text/plain": [ 81 | "
" 82 | ] 83 | }, 84 | "metadata": { 85 | "needs_background": "light" 86 | }, 87 | "output_type": "display_data" 88 | } 89 | ], 90 | "source": [ 91 | "plt.xlim(0, 3)\n", 92 | "sns.distplot(np.array(result4), label=r\"$\\alpha=3, \\beta=4$\")\n", 93 | "# sns.distplot(np.array(result3), label=r\"$\\alpha=2, \\beta=1$\")\n", 94 | "sns.distplot(np.array(result2), label=r\"$\\alpha=3, \\beta=1$\")\n", 95 | "sns.distplot(np.array(result1), label=r\"$\\alpha=3, \\beta=0.5$\")\n", 96 | "plt.legend()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 51, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from torch.distributions import Gamma" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 9, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "alpha = torch.tensor(4.4798)\n", 115 | "beta = torch.tensor(1/5.4798)\n", 116 | "q = InverseGamma(alpha, beta)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 10, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "result = []\n", 126 | "for i in range(10000):\n", 127 | " result.append(q.rsample().item())" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 11, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "" 139 | ] 140 | }, 141 | "execution_count": 11, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | }, 145 | { 146 | "data": { 147 | "image/png": "\n", 148 | "text/plain": [ 149 | "
" 150 | ] 151 | }, 152 | "metadata": { 153 | "needs_background": "light" 154 | }, 155 | "output_type": "display_data" 156 | } 157 | ], 158 | "source": [ 159 | "sns.distplot(np.array(result))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 3", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.7.1" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /test/check.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from dataset import ToyDataset\n", 10 | "from torch.utils.data import DataLoader\n", 11 | "from bayes_models import MLP, AdaptedMLP\n", 12 | "from loss import GLLLoss\n", 13 | "import torch" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "dataset = ToyDataset()\n", 23 | "dataloader = DataLoader(dataset, batch_size=500)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "train_x, train_y = list(dataloader)[0]\n", 33 | "# train_x, train_y = train_x.cuda(), train_y.cuda()" 34 | ] 35 | }, 36 | { 37 | 
"cell_type": "code", 38 | "execution_count": 4, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "prior_type = [\"empirical\", \"wider_he\", \"wider_he\"]\n", 43 | "hidden_dims = [128, 128]\n", 44 | "adapter = {\n", 45 | " 'in': {\"scale\": [[1.0]], \"shift\": [[0.0]]},\n", 46 | " 'out': {\"scale\": [[1.0, 0.83]], \"shift\": [[0.0, -3.5]]}\n", 47 | "}" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 5, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "mlp = MLP(1, 2, prior_type, hidden_dims)\n", 57 | "model = AdaptedMLP(mlp, adapter)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 6, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "MLP(\n", 69 | " (layers): Sequential(\n", 70 | " (0): LinearCertainActivations(\n", 71 | " prior=empirical, in_features=1, out_features=128, bias=True\n", 72 | " (weight): Parameter=>\n", 73 | " \tVariational Distribution: Normal(loc=torch.Size([128, 1]), scale=torch.Size([128, 1]))\n", 74 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 27.39900016784668)\n", 75 | " (bias): Parameter=>\n", 76 | " \tVariational Distribution: Normal(loc=torch.Size([128]), scale=torch.Size([128]))\n", 77 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 27.39900016784668)\n", 78 | " )\n", 79 | " (1): LinearReLU(\n", 80 | " prior=empirical, in_features=128, out_features=128, bias=True\n", 81 | " (weight): Parameter=>\n", 82 | " \tVariational Distribution: Normal(loc=torch.Size([128, 128]), scale=torch.Size([128, 128]))\n", 83 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 84 | " (bias): Parameter=>\n", 85 | " \tVariational Distribution: Normal(loc=torch.Size([128]), scale=torch.Size([128]))\n", 86 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 87 | " )\n", 88 | " (2): LinearReLU(\n", 89 | " prior=empirical, 
in_features=128, out_features=2, bias=True\n", 90 | " (weight): Parameter=>\n", 91 | " \tVariational Distribution: Normal(loc=torch.Size([2, 128]), scale=torch.Size([2, 128]))\n", 92 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 93 | " (bias): Parameter=>\n", 94 | " \tVariational Distribution: Normal(loc=torch.Size([2]), scale=torch.Size([2]))\n", 95 | " \tPrior: InverseGamma(concentration: 4.479800224304199, scale: 0.21405468881130219)\n", 96 | " )\n", 97 | " )\n", 98 | ")" 99 | ] 100 | }, 101 | "execution_count": 6, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "model" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 7, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "pred = model(train_x)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 8, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "log_variance = pred.mean[:, 1].reshape(-1)\n", 126 | "mean = pred.mean[:, 0].reshape(-1)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 9, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "sll = pred.var[:, 1, 1].reshape(-1)\n", 136 | "smm = pred.var[:, 0, 0].reshape(-1)\n", 137 | "sml = pred.var[:, 0, 1].reshape(-1)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 13, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "tensor([[-inf, -inf, -inf, ..., -inf, -inf, -inf],\n", 149 | " [-inf, -inf, -inf, ..., -inf, -inf, -inf],\n", 150 | " [-inf, -inf, -inf, ..., -inf, -inf, -inf],\n", 151 | " ...,\n", 152 | " [-inf, -inf, -inf, ..., -inf, -inf, -inf],\n", 153 | " [-inf, -inf, -inf, ..., -inf, -inf, -inf],\n", 154 | " [-inf, -inf, -inf, ..., -inf, -inf, -inf]], grad_fn=)" 155 | ] 156 | }, 157 | "execution_count": 13, 158 | "metadata": {}, 159 | "output_type": 
"execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "from loss import GLLLoss\n", 164 | "loss = GLLLoss()\n", 165 | "loss(pred, train_y)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 11, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "%matplotlib inline\n", 175 | "import matplotlib.pyplot as plt" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 12, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "import bayes_utils as bu\n", 185 | "\n", 186 | "x = torch.arange(-10, 10, 1/100)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 42, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "[]" 198 | ] 199 | }, 200 | "execution_count": 42, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | }, 204 | { 205 | "data": { 206 | "image/png": "\n", 207 | "text/plain": [ 208 | "
" 209 | ] 210 | }, 211 | "metadata": { 212 | "needs_background": "light" 213 | }, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "plt.plot(x.numpy(), bu.softrelu(x).numpy())\n", 219 | "plt.plot(x.numpy(), bu.standard_gaussian(x).numpy())\n", 220 | "plt.plot(x.numpy(), bu.gaussian_cdf(x).numpy())\n", 221 | "plt.plot(x.numpy(), (x * bu.gaussian_cdf(x)).numpy())" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 39, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "[]" 233 | ] 234 | }, 235 | "execution_count": 39, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | }, 239 | { 240 | "data": { 241 | "image/png": "\n", 242 | "text/plain": [ 243 | "
" 244 | ] 245 | }, 246 | "metadata": { 247 | "needs_background": "light" 248 | }, 249 | "output_type": "display_data" 250 | } 251 | ], 252 | "source": [ 253 | "plt.plot(x.numpy(), bu.standard_gaussian(x).numpy())" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "Python 3", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.7.1" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 2 285 | } 286 | --------------------------------------------------------------------------------