├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── core
│   ├── __init__.py
│   ├── bayesian_utils.py
│   ├── layers.py
│   ├── logger.py
│   ├── losses.py
│   ├── models.py
│   └── utils.py
├── pics
│   ├── classification
│   │   └── circles
│   │       ├── test.png
│   │       └── train.png
│   └── regression
│       ├── det
│       │   ├── last.png
│       │   └── swapped.png
│       └── mcvi
│           ├── last.png
│           └── swapped.png
└── scripts
    ├── __init__.py
    ├── fc-variance.py
    ├── lenet-det.py
    ├── lenet-variance-first.py
    ├── lenet-variance-last.py
    ├── lenet-variance.py
    ├── lenet-vdo.py
    ├── toy_data_classification.py
    └── toy_data_regression.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDE
.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Alexander Markov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DVI

PyTorch implementation of [Deterministic Variational Inference for Robust Bayesian Neural Networks](https://arxiv.org/pdf/1810.03958.pdf) with additional experiments:

- classification task
- conv nets
- variational dropout
- variance networks
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/__init__.py
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/core/__init__.py
--------------------------------------------------------------------------------
/core/bayesian_utils.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal

EPS = 1e-6


def matrix_diag_part(tensor):
    return torch.diagonal(tensor, dim1=-1, dim2=-2)


def standard_gaussian(x):
    return torch.exp(-1 / 2 * x * x) / np.sqrt(2 * math.pi)


def gaussian_cdf(x):
    const = 1 / np.sqrt(2)
    return 0.5 * (1 + torch.erf(x * const))


def softrelu(x):
    return standard_gaussian(x) + x * gaussian_cdf(x)


def heaviside_q(rho, mu1, mu2):
    """
    Compute the closed-form covariance term A(rho) * exp(-Q(rho, mu1, mu2))
    for the Heaviside activation.
    """
    rho_hat = torch.sqrt(1 - rho * rho)
    arcsin = torch.asin(rho)

    rho_s = torch.abs(rho) + EPS
    arcsin_s = torch.abs(torch.asin(rho)) + EPS / 2

    A = arcsin / (2 * math.pi)
    one_over_coef_sum = (2 * arcsin_s * rho_hat) / rho_s
    one_over_coefs_prod = (arcsin_s * rho_hat * (1 + rho_hat)) / (rho * rho)
    return A * torch.exp(-(
            mu1 * mu1 + mu2 * mu2) / one_over_coef_sum + mu1 * mu2 / one_over_coefs_prod)


def relu_q(rho, mu1, mu2):
    """
    Compute the closed-form covariance term A(rho) * exp(-Q(rho, mu1, mu2))
    for the ReLU activation.
    """
    rho_hat_plus_one = torch.sqrt(1 - rho * rho) + 1
    g_r = torch.asin(rho) - rho / rho_hat_plus_one
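    # Division binds tighter than subtraction, so the line above computes
    # g(rho) = asin(rho) - rho / (1 + sqrt(1 - rho^2)); both the minus sign
    # and the absence of extra brackets match the closed-form g(rho) factor
    # of the ReLU covariance (cf. the appendix of arXiv:1810.03958).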

    rho_s = torch.abs(rho) + EPS
    g_r_s = torch.abs(g_r) + EPS
    A = g_r / (2 * math.pi)

    coef_sum = rho_s / (2 * g_r_s * rho_hat_plus_one)
    coef_prod = (torch.asin(rho) - rho) / (rho_s * g_r_s)
    return A * torch.exp(
        -(mu1 * mu1 + mu2 * mu2) * coef_sum + coef_prod * mu1 * mu2)


def delta(rho, mu1, mu2):
    return gaussian_cdf(mu1) * gaussian_cdf(mu2) + relu_q(rho, mu1, mu2)


def compute_linear_var(x_mean, x_var, weights_mean, weights_var,
                       bias_mean=None, bias_var=None):
    x_var_diag = matrix_diag_part(x_var)
    xx_mean = x_var_diag + x_mean * x_mean

    term1_diag = torch.matmul(xx_mean, weights_var)

    flat_xCov = torch.reshape(x_var, (-1, weights_mean.size(0)))

    xCov_A = torch.matmul(flat_xCov, weights_mean)
    xCov_A = torch.reshape(xCov_A, (
        -1, weights_mean.size(0), weights_mean.size(1)))

    xCov_A = torch.transpose(xCov_A, 1, 2)
    xCov_A = torch.reshape(xCov_A, (-1, weights_mean.size(0)))

    A_xCov_A = torch.matmul(xCov_A, weights_mean)
    A_xCov_A = torch.reshape(A_xCov_A, (
        -1, weights_mean.size(1), weights_mean.size(1)))

    term2 = A_xCov_A
    term2_diag = matrix_diag_part(term2)

    _, n, _ = term2.size()
    idx = torch.arange(0, n)

    term3_diag = bias_var if bias_var is not None else 0
    result_diag = term1_diag + term2_diag + term3_diag

    result = term2
    result[:, idx, idx] = result_diag
    return result


def compute_heaviside_var(x_var, x_var_diag, mu):
    mu1 = torch.unsqueeze(mu, 2)
    mu2 = mu1.permute(0, 2, 1)

    s11s22 = torch.unsqueeze(x_var_diag, dim=2) * torch.unsqueeze(
        x_var_diag, dim=1)
    rho = x_var / torch.sqrt(s11s22)
    rho = torch.clamp(rho, -1 / (1 + 1e-6), 1 / (1 + 1e-6))
    return heaviside_q(rho, mu1, mu2)


def compute_relu_var(x_var, x_var_diag, mu):
    mu1 = torch.unsqueeze(mu, 2)
    mu2 = mu1.permute(0, 2, 1)

    s11s22 = torch.unsqueeze(x_var_diag, dim=2) * torch.unsqueeze(
        x_var_diag, dim=1)
    rho = x_var / (torch.sqrt(s11s22) + EPS)
    rho = torch.clamp(rho, -1 / (1 + EPS), 1 / (1 + EPS))
    return x_var * delta(rho, mu1, mu2)


def kl_gaussian(p_mean, p_var, q_mean, q_var):
    """
    Compute KL(p || q) for two factorized (diagonal-covariance) Gaussians.

    :param p_mean: mean of p
    :param p_var: variance of p
    :param q_mean: mean of q (the prior)
    :param q_var: variance of q (the prior)
    :return: scalar KL divergence, summed over all dimensions
    """
    s_q_var = q_var + EPS
    entropy = 0.5 * (1 + math.log(2 * math.pi) + torch.log(p_var))
    cross_entropy = 0.5 * (math.log(2 * math.pi) + torch.log(s_q_var) +
                           (p_var + (p_mean - q_mean) ** 2) / s_q_var)
    return torch.sum(cross_entropy - entropy)


def kl_loguni(log_alpha):
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    C = -k1
    mdkl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(
        torch.exp(-log_alpha)) + C
    kl = -torch.sum(mdkl)
    return kl


def logsumexp_mean(y, keepdim=True):
    """
    Compute a second-order approximation of E[logsumexp(y)] for Gaussian y.

    :param y: tuple of (y_mean, y_var)
        y_mean dim [batch_size, hid_dim]
        y_var dim [batch_size, hid_dim, hid_dim]
    :return: approximate expected logsumexp, shape [batch_size, 1]
    """
    y_mean = y[0]
    y_var = y[1]
    logsumexp = torch.logsumexp(y_mean, dim=-1, keepdim=keepdim)
    p = torch.exp(y_mean - logsumexp)

    pTdiagVar = torch.sum(p * matrix_diag_part(y_var), dim=-1, keepdim=keepdim)
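    # The next statement evaluates the batched bilinear form p^T Var p,
    # i.e. (1 x K) @ (K x K) @ (K x 1) per batch element; it enters the
    # correction term 0.5 * (p^T diag(Var) - p^T Var p) added to logsumexp.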
pTVarp = torch.squeeze(torch.matmul(torch.unsqueeze(p, 1), 165 | torch.matmul(y_var, 166 | torch.unsqueeze(p, 2))), 167 | dim=-1) 168 | 169 | return logsumexp + 0.5 * (pTdiagVar - pTVarp) 170 | 171 | 172 | def logsoftmax_mean(y): 173 | """ 174 | Compute 175 | :param y: 176 | :param y: tuple of (y_mean, y_var) 177 | y_mean dim [batch_size, hid_dim] 178 | y_var dim [batch_size, hid_dim, hid_dim] 179 | 180 | """ 181 | return y[0] - logsumexp_mean(y) 182 | 183 | 184 | def sample_activations(x, n_samples): 185 | x_mean, x_var = x[0], x[1] 186 | sampler = MultivariateNormal(loc=x_mean, covariance_matrix=x_var) 187 | samples = sampler.rsample([n_samples]) 188 | return samples 189 | 190 | 191 | def sample_logsoftmax(x, n_samples): 192 | activations = sample_activations(x, n_samples) 193 | logsoftmax = F.log_softmax(activations, dim=1) 194 | return torch.mean(logsoftmax, dim=0) 195 | 196 | 197 | def sample_softmax(x, n_samples): 198 | activations = sample_activations(x, n_samples) 199 | softmax = F.softmax(activations, dim=1) 200 | return torch.mean(softmax, dim=0) 201 | 202 | 203 | def classification_posterior(activations): 204 | mean, var = activations 205 | p = F.softmax(mean, dim=1) 206 | diagVar = matrix_diag_part(var) 207 | pTdiagVar = torch.sum(p * diagVar, dim=-1, keepdim=True) 208 | pTVarp = torch.squeeze(torch.matmul(torch.unsqueeze(p, 1), 209 | torch.matmul(var, 210 | torch.unsqueeze(p, 2))), 211 | dim=-1) 212 | Varp = torch.squeeze(torch.matmul(var, torch.unsqueeze(p, 2)), dim=-1) 213 | 214 | return p * (1 + pTVarp - Varp + 0.5 * diagVar - 0.5 * pTdiagVar) 215 | -------------------------------------------------------------------------------- /core/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions import MultivariateNormal, Independent, Normal 7 | 8 | from .bayesian_utils import kl_gaussian, softrelu, matrix_diag_part, kl_loguni, \ 9 | compute_linear_var, compute_relu_var, standard_gaussian, gaussian_cdf, \ 10 | compute_heaviside_var 11 | 12 | EPS = 1e-6 13 | 14 | 15 | class LinearGaussian(nn.Module): 16 | def __init__(self, in_features, out_features, certain=False, 17 | deterministic=True): 18 | """ 19 | Applies linear transformation y = xA^T + b 20 | 21 | A and b are Gaussian random variables 22 | 23 | :param in_features: input dimension 24 | :param out_features: output dimension 25 | :param certain: if true, than x is equal to its mean and has no variance 26 | """ 27 | 28 | super().__init__() 29 | 30 | self.in_features = in_features 31 | self.out_features = out_features 32 | 33 | self.W = nn.Parameter(torch.Tensor(in_features, out_features)) 34 | self.bias = nn.Parameter(torch.Tensor(out_features)) 35 | 36 | self.W_logvar = nn.Parameter(torch.Tensor(in_features, out_features)) 37 | self.bias_logvar = nn.Parameter(torch.Tensor(out_features)) 38 | 39 | self._initialize_weights() 40 | self._construct_priors() 41 | 42 | self.certain = certain 43 | self.deterministic = deterministic 44 | self.mean_forward = False 45 | self.zero_mean = False 46 | 47 | def _initialize_weights(self): 48 | nn.init.xavier_normal_(self.W) 49 | nn.init.normal_(self.bias) 50 | 51 | nn.init.uniform_(self.W_logvar, a=-10, b=-7) 52 | nn.init.uniform_(self.bias_logvar, a=-10, b=-7) 53 | 54 | def _construct_priors(self): 55 | self.W_mean_prior = nn.Parameter(torch.zeros_like(self.W), 56 | requires_grad=False) 57 | self.W_var_prior = 
nn.Parameter(torch.ones_like(self.W_logvar) * 0.1, 58 | requires_grad=False) 59 | 60 | self.bias_mean_prior = nn.Parameter(torch.zeros_like(self.bias), 61 | requires_grad=False) 62 | self.bias_var_prior = nn.Parameter( 63 | torch.ones_like(self.bias_logvar) * 0.1, 64 | requires_grad=False) 65 | 66 | def _get_var(self, param): 67 | return torch.exp(param) 68 | 69 | def compute_kl(self): 70 | weights_kl = kl_gaussian(self.W, self._get_var(self.W_logvar), 71 | self.W_mean_prior, self.W_var_prior) 72 | bias_kl = kl_gaussian(self.bias, self._get_var(self.bias_logvar), 73 | self.bias_mean_prior, self.bias_var_prior) 74 | return weights_kl + bias_kl 75 | 76 | def set_flag(self, flag_name, value): 77 | setattr(self, flag_name, value) 78 | for m in self.children(): 79 | if hasattr(m, 'set_flag'): 80 | m.set_flag(flag_name, value) 81 | 82 | def forward(self, x): 83 | """ 84 | Compute expectation and variance after linear transform 85 | y = xA^T + b 86 | 87 | :param x: input, size [batch, in_features] 88 | :return: tuple (y_mean, y_var) for deterministic mode:, shapes: 89 | y_mean: [batch, out_features] 90 | y_var: [batch, out_features, out_features] 91 | 92 | tuple (sample, None) for MCVI mode, 93 | sample : [batch, out_features] - local reparametrization of output 94 | """ 95 | x = self._apply_activation(x) 96 | if self.zero_mean: 97 | return self._zero_mean_forward(x) 98 | elif self.mean_forward: 99 | return self._mean_forward(x) 100 | elif self.deterministic: 101 | return self._det_forward(x) 102 | else: 103 | return self._mcvi_forward(x) 104 | 105 | def _mcvi_forward(self, x): 106 | W_var = self._get_var(self.W_logvar) 107 | bias_var = self._get_var(self.bias_logvar) 108 | 109 | if self.certain: 110 | x_mean = x 111 | x_var = None 112 | else: 113 | x_mean = x[0] 114 | x_var = x[1] 115 | 116 | y_mean = F.linear(x_mean, self.W.t()) + self.bias 117 | 118 | if self.certain or not self.deterministic: 119 | xx = x_mean * x_mean 120 | y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var) 121 | else: 122 | y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias, 123 | bias_var) 124 | 125 | dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var) 126 | sample = dst.rsample() 127 | return sample, None 128 | 129 | def _det_forward(self, x): 130 | W_var = self._get_var(self.W_logvar) 131 | bias_var = self._get_var(self.bias_logvar) 132 | 133 | if self.certain: 134 | x_mean = x 135 | x_var = None 136 | else: 137 | x_mean = x[0] 138 | x_var = x[1] 139 | 140 | y_mean = F.linear(x_mean, self.W.t()) + self.bias 141 | 142 | if self.certain or x_var is None: 143 | xx = x_mean * x_mean 144 | y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var) 145 | else: 146 | y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias, 147 | bias_var) 148 | 149 | return y_mean, y_var 150 | 151 | def _mean_forward(self, x): 152 | if not isinstance(x, tuple): 153 | x_mean = x 154 | else: 155 | x_mean = x[0] 156 | 157 | y_mean = F.linear(x_mean, self.W.t()) + self.bias 158 | return y_mean, None 159 | 160 | def _zero_mean_forward(self, x): 161 | if not isinstance(x, tuple): 162 | x_mean = x 163 | x_var = None 164 | else: 165 | x_mean = x[0] 166 | x_var = x[1] 167 | 168 | y_mean = F.linear(x_mean, torch.zeros_like(self.W).t()) + self.bias 169 | 170 | W_var = self._get_var(self.W_logvar) 171 | bias_var = self._get_var(self.bias_logvar) 172 | 173 | if x_var is None: 174 | xx = x_mean * x_mean 175 | y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var) 176 | else: 177 | y_var = 
compute_linear_var(x_mean, x_var, torch.zeros_like(self.W), 178 | W_var, self.bias, bias_var) 179 | 180 | if self.deterministic: 181 | return y_mean, y_var 182 | else: 183 | dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var) 184 | sample = dst.rsample() 185 | return sample, None 186 | 187 | def _apply_activation(self, x): 188 | return x 189 | 190 | def __repr__(self): 191 | return self.__class__.__name__ + '(' \ 192 | + 'in_features=' + str(self.in_features) \ 193 | + ', out_features=' + str(self.out_features) + ')' 194 | 195 | 196 | class ReluGaussian(LinearGaussian): 197 | def _apply_activation(self, x): 198 | if isinstance(x, tuple): 199 | x_mean = x[0] 200 | x_var = x[1] 201 | else: 202 | x_mean = x 203 | x_var = None 204 | 205 | if x_var is None: 206 | z_mean = F.relu(x_mean) 207 | z_var = None 208 | else: 209 | x_var_diag = matrix_diag_part(x_var) 210 | sqrt_x_var_diag = torch.sqrt(x_var_diag + EPS) 211 | mu = x_mean / (sqrt_x_var_diag + EPS) 212 | 213 | z_mean = sqrt_x_var_diag * softrelu(mu) 214 | z_var = compute_relu_var(x_var, x_var_diag, mu) 215 | 216 | return z_mean, z_var 217 | 218 | 219 | class HeavisideGaussian(LinearGaussian): 220 | def _apply_activation(self, x): 221 | x_mean = x[0] 222 | x_var = x[1] 223 | 224 | if x_var is None: 225 | x_var = x_mean * x_mean 226 | 227 | x_var_diag = matrix_diag_part(x_var) 228 | 229 | sqrt_x_var_diag = torch.sqrt(x_var_diag) 230 | mu = x_mean / (sqrt_x_var_diag + EPS) 231 | 232 | z_mean = gaussian_cdf(mu) 233 | z_var = compute_heaviside_var(x_var, x_var_diag, mu) 234 | 235 | return z_mean, z_var 236 | 237 | 238 | class DeterministicGaussian(LinearGaussian): 239 | def __init__(self, in_features, out_features, certain=False, 240 | deterministic=True): 241 | """ 242 | Applies linear transformation y = xA^T + b 243 | 244 | A and b are Gaussian random variables 245 | 246 | :param in_features: input dimension 247 | :param out_features: output dimension 248 | :param certain: if true, than x is equal to its mean and has no variance 249 | """ 250 | 251 | super().__init__(in_features, out_features, certain, deterministic) 252 | self.W_logvar.requires_grad = False 253 | self.bias_logvar.requires_grad = False 254 | 255 | def compute_kl(self): 256 | return 0 257 | 258 | 259 | class DeterministicReluGaussian(ReluGaussian): 260 | def __init__(self, in_features, out_features, certain=False, 261 | deterministic=True): 262 | """ 263 | Applies linear transformation y = xA^T + b 264 | 265 | A and b are Gaussian random variables 266 | 267 | :param in_features: input dimension 268 | :param out_features: output dimension 269 | :param certain: if true, than x is equal to its mean and has no variance 270 | """ 271 | 272 | super().__init__(in_features, out_features, certain, deterministic) 273 | self.W_logvar.requires_grad = False 274 | self.bias_logvar.requires_grad = False 275 | 276 | def compute_kl(self): 277 | return 0 278 | 279 | 280 | class LinearVDO(nn.Module): 281 | 282 | def __init__(self, in_features, out_features, prior='loguni', 283 | alpha_shape=(1, 1), bias=True, deterministic=True): 284 | super(LinearVDO, self).__init__() 285 | self.in_features = in_features 286 | self.out_features = out_features 287 | self.alpha_shape = alpha_shape 288 | self.W = nn.Parameter(torch.Tensor(out_features, in_features)) 289 | self.log_alpha = nn.Parameter(torch.Tensor(*alpha_shape)) 290 | if bias: 291 | self.bias = nn.Parameter(torch.Tensor(1, out_features)) 292 | else: 293 | self.register_parameter('bias', None) 294 | self.reset_parameters() 295 | 
        self.zero_mean = False
        self.permute_sigma = False
        self.prior = prior
        self.kl_fun = kl_loguni
        self.deterministic = deterministic

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(1))
        self.W.data.uniform_(-stdv, stdv)
        self.log_alpha.data.fill_(-5.0)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        if self.deterministic:
            return self._det_forward(x)
        else:
            return self._mc_forward(x)

    def _mc_forward(self, x):
        if isinstance(x, tuple):
            x_mean = x[0]
            x_var = x[1]
        else:
            x_mean, x_var = x, None  # x_var was previously left undefined here

        if self.zero_mean:
            lrt_mean = 0.0
        else:
            lrt_mean = F.linear(x_mean, self.W)
        if self.bias is not None:
            lrt_mean = lrt_mean + self.bias

        sigma2 = torch.exp(self.log_alpha) * self.W * self.W
        if self.permute_sigma:
            sigma2 = sigma2.view(-1)[torch.randperm(
                self.in_features * self.out_features).cuda()].view(
                self.out_features, self.in_features)

        if x_var is None:
            x_var = torch.diag_embed(x_mean * x_mean)

        lrt_cov = compute_linear_var(x_mean, x_var, self.W.t(), sigma2.t())
        dst = MultivariateNormal(lrt_mean, covariance_matrix=lrt_cov)
        return dst.rsample(), None

    def compute_kl(self):
        return self.W.nelement() * self.kl_fun(
            self.log_alpha) / self.log_alpha.nelement()

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', alpha_shape=' + str(self.alpha_shape) \
               + ', prior=' + self.prior \
               + ', bias=' + str(self.bias is not None) + ')'

    def _det_forward(self, x):
        if isinstance(x, tuple):
            x_mean = x[0]
            x_var = x[1]
        else:
            x_mean = x
            x_var = torch.diag_embed(x_mean * x_mean)

        batch_size = x_mean.size(0)
        sigma2 = torch.exp(self.log_alpha) * self.W * self.W
        if self.zero_mean:
            y_mean = torch.zeros(batch_size, self.out_features).to(
                x_mean.device)
        else:
            y_mean = F.linear(x_mean, self.W)
        if self.bias is not None:
            y_mean = y_mean + self.bias

        y_var = compute_linear_var(x_mean, x_var, self.W.t(), sigma2.t())
        return y_mean, y_var

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)


class ReluVDO(LinearVDO):
    def forward(self, x):
        x = self._apply_activation(x)
        return super().forward(x)

    def _apply_activation(self, x):
        if isinstance(x, tuple):
            x_mean = x[0]
            x_var = x[1]
        else:
            x_mean = x
            x_var = None

        if x_var is None:
            z_mean = F.relu(x_mean)
            z_var = None
        else:
            x_var_diag = matrix_diag_part(x_var)
            sqrt_x_var_diag = torch.sqrt(x_var_diag + EPS)
            mu = x_mean / (sqrt_x_var_diag + EPS)

            z_mean = sqrt_x_var_diag * softrelu(mu)
            z_var = compute_relu_var(x_var, x_var_diag, mu)

        return z_mean, z_var


class HeavisideVDO(LinearVDO):
    def forward(self, x):
        x = self._apply_activation(x)
        return super().forward(x)

    def _apply_activation(self, x):
        x_mean = x[0]
        x_var = x[1]

        if x_var
is None: 419 | x_var = x_mean * x_mean 420 | 421 | x_var_diag = matrix_diag_part(x_var) 422 | 423 | sqrt_x_var_diag = torch.sqrt(x_var_diag) 424 | mu = x_mean / (sqrt_x_var_diag + EPS) 425 | 426 | z_mean = gaussian_cdf(mu) 427 | z_var = compute_heaviside_var(x_var, x_var_diag, mu) 428 | 429 | return z_mean, z_var 430 | 431 | 432 | class VarianceGaussian(LinearGaussian): 433 | def __init__(self, in_features, out_features, 434 | certain=False, deterministic=True, sigma_sq=False): 435 | super().__init__(in_features, out_features, certain, deterministic) 436 | self.W.data.fill_(0) 437 | self.W.requires_grad = False 438 | self.sigma_sq = sigma_sq 439 | if sigma_sq: 440 | self.W_logvar.data.uniform_(-1 / (in_features + out_features), 441 | 1 / (in_features + out_features)) 442 | self.bias_logvar.data.uniform_(-1 / out_features, 1 / out_features) 443 | 444 | def _zero_mean_forward(self, x): 445 | if self.deterministic: 446 | return self._det_forward(x) 447 | else: 448 | return self._mcvi_forward(x) 449 | 450 | def _get_var(self, param): 451 | if self.sigma_sq: 452 | return param * param 453 | else: 454 | return torch.exp(param) 455 | 456 | def compute_kl(self): 457 | return 0 458 | 459 | 460 | class VarianceReluGaussian(ReluGaussian): 461 | def __init__(self, in_features, out_features, 462 | certain=False, deterministic=True, sigma_sq=False): 463 | super().__init__(in_features, out_features, certain, deterministic) 464 | self.W.data.fill_(0) 465 | self.W.requires_grad = False 466 | self.sigma_sq = sigma_sq 467 | if sigma_sq: 468 | self.W_logvar.data.uniform_(-1 / (in_features + out_features), 469 | 1 / (in_features + out_features)) 470 | self.bias_logvar.data.uniform_(-1 / out_features, 1 / out_features) 471 | 472 | def _get_var(self, param): 473 | if self.sigma_sq: 474 | return param * param 475 | else: 476 | return torch.exp(param) 477 | 478 | def _zero_mean_forward(self, x): 479 | if self.deterministic: 480 | return self._det_forward(x) 481 | else: 482 | return self._mcvi_forward(x) 483 | 484 | def compute_kl(self): 485 | return 0 486 | 487 | 488 | class MeanFieldConv2d(nn.Module): 489 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 490 | activation='relu', padding=0, certain=False, 491 | deterministic=True): 492 | super().__init__() 493 | 494 | self.in_channels = in_channels 495 | self.out_channels = out_channels 496 | self.stride = stride 497 | self.padding = padding 498 | self.activation = activation.strip().lower() 499 | 500 | if not isinstance(kernel_size, tuple): 501 | kernel_size = (kernel_size, kernel_size) 502 | self.kernel_size = kernel_size 503 | 504 | self.W = nn.Parameter( 505 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 506 | self.W_logvar = nn.Parameter( 507 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 508 | 509 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 510 | self.bias_logvar = nn.Parameter(torch.Tensor(out_channels)) 511 | 512 | self._initialize_weights() 513 | self._construct_priors() 514 | 515 | self.certain = certain 516 | self.deterministic = deterministic 517 | self.mean_forward = False 518 | self.zero_mean = False 519 | 520 | def _initialize_weights(self): 521 | nn.init.xavier_normal_(self.W) 522 | nn.init.normal_(self.bias) 523 | 524 | nn.init.uniform_(self.W_logvar, a=-10, b=-7) 525 | nn.init.uniform_(self.bias_logvar, a=-10, b=-7) 526 | 527 | def _get_var(self, param): 528 | return torch.exp(param) 529 | 530 | def _construct_priors(self): 531 | self.W_mean_prior = 
nn.Parameter(torch.zeros_like(self.W),
                                           requires_grad=False)
        self.W_var_prior = nn.Parameter(torch.ones_like(self.W_logvar) * 0.1,
                                        requires_grad=False)

        self.bias_mean_prior = nn.Parameter(torch.zeros_like(self.bias),
                                            requires_grad=False)
        self.bias_var_prior = nn.Parameter(
            torch.ones_like(self.bias_logvar) * 0.1,
            requires_grad=False)

    def compute_kl(self):
        weights_kl = kl_gaussian(self.W, self._get_var(self.W_logvar),
                                 self.W_mean_prior, self.W_var_prior)
        bias_kl = kl_gaussian(self.bias, self._get_var(self.bias_logvar),
                              self.bias_mean_prior, self.bias_var_prior)
        return weights_kl + bias_kl

    def forward(self, x):
        x = self._apply_activation(x)
        if self.zero_mean:
            return self._zero_mean_forward(x)
        elif self.mean_forward:
            return self._mean_forward(x)
        elif self.deterministic:
            return self._det_forward(x)
        else:
            return self._mcvi_forward(x)

    def _zero_mean_forward(self, x):
        if self.certain or not self.deterministic:
            x_mean = x if not isinstance(x, tuple) else x[0]
            x_var = x_mean * x_mean
        else:
            x_mean = x[0]
            x_var = x[1]

        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        z_mean = F.conv2d(x_mean, torch.zeros_like(self.W), self.bias,
                          self.stride,
                          self.padding)
        z_var = F.conv2d(x_var, W_var, bias_var, self.stride,
                         self.padding)

        if self.deterministic:
            return z_mean, z_var
        else:
            # Normal expects a standard deviation, so take the square root
            # of the variance map (z_var was previously passed directly).
            dst = Independent(Normal(z_mean, torch.sqrt(z_var + EPS)), 1)
            sample = dst.rsample()
            return sample, None

    def _mean_forward(self, x):
        if not isinstance(x, tuple):
            x_mean = x
        else:
            x_mean = x[0]

        z_mean = F.conv2d(x_mean, self.W, self.bias,
                          self.stride,
                          self.padding)
        return z_mean, None

    def _det_forward(self, x):
        if self.certain and isinstance(x, tuple):
            x_mean = x[0]
            x_var = x_mean * x_mean
        elif not self.certain:
            x_mean = x[0]
            x_var = x[1]
        else:
            x_mean = x
            x_var = x_mean * x_mean

        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        z_mean = F.conv2d(x_mean, self.W, self.bias,
                          self.stride,
                          self.padding)
        z_var = F.conv2d(x_var, W_var, bias_var, self.stride,
                         self.padding)
        return z_mean, z_var

    def _mcvi_forward(self, x):
        if self.certain or not self.deterministic:
            x_mean = x if not isinstance(x, tuple) else x[0]
            x_var = x_mean * x_mean
        else:
            x_mean = x[0]
            x_var = x[1]

        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        z_mean = F.conv2d(x_mean, self.W, self.bias,
                          self.stride,
                          self.padding)
        z_var = F.conv2d(x_var, W_var, bias_var, self.stride,
                         self.padding)

        # Normal expects a standard deviation, not a variance
        dst = Independent(Normal(z_mean, torch.sqrt(z_var + EPS)), 1)
        sample = dst.rsample()
        return sample, None

    def _apply_activation(self, x):
        if self.activation == 'relu' and not self.certain:
            x_mean, x_var = x
            if x_var is None:
                x_var = x_mean * x_mean

            sqrt_x_var = torch.sqrt(x_var + EPS)
            mu =
x_mean / sqrt_x_var 651 | z_mean = sqrt_x_var * softrelu(mu) 652 | z_var = x_var * (mu * standard_gaussian(mu) + ( 653 | 1 + mu ** 2) * gaussian_cdf(mu)) 654 | return z_mean, z_var 655 | else: 656 | return x 657 | 658 | def set_flag(self, flag_name, value): 659 | setattr(self, flag_name, value) 660 | for m in self.children(): 661 | if hasattr(m, 'set_flag'): 662 | m.set_flag(flag_name, value) 663 | 664 | def __repr__(self): 665 | return self.__class__.__name__ + '(' \ 666 | + 'in_channels=' + str(self.in_channels) \ 667 | + ', out_channels=' + str(self.out_channels) \ 668 | + ', kernel_size=' + str(self.kernel_size) \ 669 | + ', stride=' + str(self.stride) \ 670 | + ', padding=' + str(self.padding) \ 671 | + ', activation=' + str(self.activation) + ')' 672 | 673 | 674 | class VarianceMeanFieldConv2d(MeanFieldConv2d): 675 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 676 | activation='relu', padding=0, certain=False, 677 | deterministic=True, sigma_sq=False): 678 | super().__init__(in_channels, out_channels, kernel_size, stride, 679 | activation, padding, certain, deterministic) 680 | self.W.data.fill_(0) 681 | self.W.requires_grad = False 682 | self.sigma_sq = sigma_sq 683 | if sigma_sq: 684 | self.W_logvar.data.uniform_(-1 / (in_channels + out_channels), 685 | 1 / (in_channels + out_channels)) 686 | self.bias_logvar.data.uniform_(-1 / out_channels, 1 / out_channels) 687 | 688 | def _get_var(self, param): 689 | if self.sigma_sq: 690 | return param * param 691 | else: 692 | return torch.exp(param) 693 | 694 | def compute_kl(self): 695 | return 0 696 | 697 | def _zero_mean_forward(self, x): 698 | if self.deterministic: 699 | return self._det_forward(x) 700 | else: 701 | return self._mcvi_forward(x) 702 | 703 | 704 | class MeanFieldConv2dVDO(nn.Module): 705 | 706 | def __init__(self, in_channels, out_channels, kernel_size, alpha_shape, 707 | certain=False, activation='relu', deterministic=True, stride=1, 708 | padding=0, dilation=1, prior='loguni', bias=True): 709 | super().__init__() 710 | self.in_channels = in_channels 711 | self.out_channels = out_channels 712 | self.kernel_size = (kernel_size, kernel_size) 713 | self.stride = stride 714 | self.padding = padding 715 | self.activation = activation 716 | self.dilation = dilation 717 | self.alpha_shape = alpha_shape 718 | self.groups = 1 719 | self.weight = nn.Parameter( 720 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 721 | if bias: 722 | self.bias = nn.Parameter(torch.Tensor(1, out_channels, 1, 1)) 723 | else: 724 | self.register_parameter('bias', None) 725 | self.op_bias = lambda input, kernel: F.conv2d(input, kernel, 726 | self.bias.flatten(), 727 | self.stride, self.padding, 728 | self.dilation, 729 | self.groups) 730 | self.op_nobias = lambda input, kernel: F.conv2d(input, kernel, None, 731 | self.stride, 732 | self.padding, 733 | self.dilation, 734 | self.groups) 735 | self.log_alpha = nn.Parameter(torch.Tensor(*alpha_shape)) 736 | self.reset_parameters() 737 | 738 | self.certain = certain 739 | self.deterministic = deterministic 740 | self.mean_forward = False 741 | self.zero_mean = False 742 | self.permute_sigma = False 743 | 744 | self.prior = prior 745 | self.kl_fun = kl_loguni 746 | 747 | def reset_parameters(self): 748 | n = self.in_channels 749 | for k in self.kernel_size: 750 | n *= k 751 | stdv = 1. 
/ math.sqrt(n) 752 | self.weight.data.uniform_(-stdv, stdv) 753 | if self.bias is not None: 754 | self.bias.data.uniform_(-stdv, stdv) 755 | self.log_alpha.data.fill_(-5.0) 756 | 757 | def forward(self, x): 758 | x = self._apply_activation(x) 759 | if self.zero_mean: 760 | return self._zero_mean_forward(x) 761 | elif self.mean_forward: 762 | return self._mean_forward(x) 763 | elif self.deterministic: 764 | return self._det_forward(x) 765 | else: 766 | return self._mcvi_forward(x) 767 | 768 | def _apply_activation(self, x): 769 | if self.activation == 'relu' and not self.certain: 770 | x_mean, x_var = x 771 | if x_var is None: 772 | x_var = x_mean * x_mean 773 | 774 | sqrt_x_var = torch.sqrt(x_var + EPS) 775 | mu = x_mean / sqrt_x_var 776 | z_mean = sqrt_x_var * softrelu(mu) 777 | z_var = x_var * (mu * standard_gaussian(mu) + ( 778 | 1 + mu ** 2) * gaussian_cdf(mu)) 779 | return z_mean, z_var 780 | else: 781 | return x 782 | 783 | def _zero_mean_forward(self, x): 784 | if self.certain or not self.deterministic: 785 | x_mean = x if not isinstance(x, tuple) else x[0] 786 | x_var = x_mean * x_mean 787 | else: 788 | x_mean = x[0] 789 | x_var = x[1] 790 | 791 | W_var = torch.exp(self.log_alpha) * self.weight * self.weight 792 | 793 | z_mean = F.conv2d(x_mean, torch.zeros_like(self.weight), self.bias, 794 | self.stride, 795 | self.padding) 796 | z_var = F.conv2d(x_var, W_var, bias=None, stride=self.stride, 797 | padding=self.padding) 798 | 799 | if self.deterministic: 800 | return z_mean, z_var 801 | else: 802 | dst = Independent(Normal(z_mean, z_var), 1) 803 | sample = dst.rsample() 804 | return sample, None 805 | 806 | def _mean_forward(self, x): 807 | if not isinstance(x, tuple): 808 | x_mean = x 809 | else: 810 | x_mean = x[0] 811 | 812 | z_mean = F.conv2d(x_mean, self.weight, self.bias, 813 | self.stride, 814 | self.padding) 815 | return z_mean, None 816 | 817 | def _det_forward(self, x): 818 | if self.certain and isinstance(x, tuple): 819 | x_mean = x[0] 820 | x_var = x_mean * x_mean 821 | elif not self.certain: 822 | x_mean = x[0] 823 | x_var = x[1] 824 | else: 825 | x_mean = x 826 | x_var = x_mean * x_mean 827 | 828 | W_var = torch.exp(self.log_alpha) * self.weight * self.weight 829 | 830 | z_mean = F.conv2d(x_mean, self.weight, self.bias.flatten(), 831 | self.stride, 832 | self.padding) 833 | z_var = F.conv2d(x_var, W_var, bias=None, stride=self.stride, 834 | padding=self.padding) 835 | return z_mean, z_var 836 | 837 | def _mcvi_forward(self, x): 838 | if isinstance(x, tuple): 839 | x_mean = x[0] 840 | x_var = x[1] 841 | else: 842 | x_mean = x 843 | x_var = x_mean * x_mean 844 | 845 | if self.zero_mean: 846 | lrt_mean = self.op_bias(x_mean, 0.0 * self.weight) 847 | else: 848 | lrt_mean = self.op_bias(x_mean, self.weight) 849 | 850 | sigma2 = torch.exp(self.log_alpha) * self.weight * self.weight 851 | if self.permute_sigma: 852 | sigma2 = sigma2.view(-1)[ 853 | torch.randperm(self.weight.nelement()).cuda()].view( 854 | self.weight.shape) 855 | 856 | lrt_std = torch.sqrt(1e-16 + self.op_nobias(x_var, sigma2)) 857 | if self.training: 858 | eps = lrt_std.data.new(lrt_std.size()).normal_() 859 | else: 860 | eps = 0.0 861 | return lrt_mean + lrt_std * eps, None 862 | 863 | def compute_kl(self): 864 | return self.weight.nelement() / self.log_alpha.nelement() * kl_loguni( 865 | self.log_alpha) 866 | 867 | def __repr__(self): 868 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 869 | ', stride={stride}') 870 | s += ', padding={padding}' 871 | s += ', alpha_shape=' + 
str(self.alpha_shape)
        s += ', prior=' + self.prior
        s += ', dilation={dilation}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)


class AveragePoolGaussian(nn.Module):
    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()

        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size, kernel_size)

        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        if not isinstance(x, tuple):
            raise ValueError(
                "Input for pooling layer should be tuple of tensors")

        x_mean, x_var = x
        z_mean = F.avg_pool2d(x_mean, self.kernel_size, self.stride,
                              self.padding)
        if x_var is None:
            z_var = None
        else:
            n = self.kernel_size[0] * self.kernel_size[1]
            z_var = F.avg_pool2d(x_var, self.kernel_size, self.stride,
                                 self.padding) / n
        return z_mean, z_var

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'kernel_size=' + str(self.kernel_size) \
               + ', stride=' + str(self.stride) \
               + ', padding=' + str(self.padding) + ')'

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
import os
import random
import sys
from collections import OrderedDict
from time import gmtime, strftime

import numpy as np
from pandas import DataFrame
from tabulate import tabulate


class Logger:
    def __init__(self, name='name', fmt=None):
        self.handler = True
        self.scalar_metrics = OrderedDict()
        self.fmt = fmt if fmt else dict()

        base = './logs'
        if not os.path.exists(base):
            os.mkdir(base)

        time = gmtime()
        hash = ''.join([chr(random.randint(97, 122)) for _ in range(5)])
        fname = '-'.join(sys.argv[0].split('/')[-3:])
        self.path = '%s/%s-%s-%s-%s' % (
            base, fname, name, strftime('%m-%d-%H-%M', time), hash)

        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.checkpoint = self.path + '.cpt'

        def prin(*args):
            str_to_write = ' '.join(map(str, args))
            with open(self.output, 'a') as f:
                f.write(str_to_write + '\n')
                f.flush()

            print(str_to_write)
            sys.stdout.flush()

        self.print = prin

    def add_scalar(self, t, key, value):
        if key not in self.scalar_metrics:
            self.scalar_metrics[key] = []
        self.scalar_metrics[key] += [(t, value)]

    def add_dict(self, t, d):
        # dict.iteritems() is Python 2 only; use items() under Python 3
        for key, value in d.items():
            self.add_scalar(t, key, value)

    def add(self, t, **args):
        for key, value in args.items():
            self.add_scalar(t, key, value)

    def iter_info(self, order=None):
        names = list(self.scalar_metrics.keys())
        if order:
            names = order
        values = [self.scalar_metrics[name][-1][1] for name in names]
        t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
        fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.1f'
                        for name in names]

        if self.handler:
            self.handler = False
            self.print(
                tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt))
        else:
            self.print(
                tabulate([[t] + values], ['epoch'] + names, tablefmt='plain',
                         floatfmt=fmt).split('\n')[1])

    def save(self, silent=False):
        result = None
        for key in self.scalar_metrics.keys():
            if result is None:
                result = DataFrame(self.scalar_metrics[key],
                                   columns=['t', key]).set_index('t')
            else:
                df = DataFrame(self.scalar_metrics[key],
                               columns=['t', key]).set_index('t')
                result = result.join(df, how='outer')
        result.to_csv(self.logs)
        if not silent:
            self.print(
                'The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')
--------------------------------------------------------------------------------
/core/losses.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import clip

from .bayesian_utils import logsoftmax_mean
from .utils import classification_posterior

EPS = 1e-8


class RegressionLoss(nn.Module):
    def __init__(self, net, args):
        """
        Compute the ELBO for a regression task.

        :param net: neural network; child modules may expose compute_kl()
        :param args: run configuration; uses args.method, args.heteroskedastic,
            args.homo_log_var_scale, args.mcvi, args.device,
            args.warmup_updates, args.anneal_updates and args.batch_size
        """
        super().__init__()

        self.net = net
        self.method = args.method
        self.use_het = args.heteroskedastic
        self.det = not args.mcvi
        self.homo_log_var_scale = torch.FloatTensor(
            [args.homo_log_var_scale]).to(device=args.device)
        if not self.use_het and self.homo_log_var_scale is None:
            raise ValueError(
                "homo_log_var_scale must be set in homoskedastic mode")
        self.warmup = args.warmup_updates
        self.anneal = args.anneal_updates
        self.batch_size = args.batch_size

    def gaussian_likelihood_core(self, target, mean, log_var, smm, sml, sll):
        const = math.log(2 * math.pi)
        exp = torch.exp(-log_var + 0.5 * (sll + EPS))
        return -0.5 * (
                const + log_var + exp * (smm + (mean - sml - target) ** 2))

    def heteroskedastic_gaussian_loglikelihood(self, pred_mean, pred_var,
                                               target):
        log_var = pred_mean[:, 1].view(-1)
        mean = pred_mean[:, 0].view(-1)

        if self.method.lower() == 'bayes':
            sll = pred_var[:, 1, 1].view(-1)
            smm = pred_var[:, 0, 0].view(-1)
            sml = pred_var[:, 0, 1].view(-1)
        else:
            sll = smm = sml = 0
        return self.gaussian_likelihood_core(target, mean, log_var, smm, sml,
                                             sll)

    def homoskedastic_gaussian_loglikelihood(self, pred_mean, pred_var, target):
        log_var = self.homo_log_var_scale
        if self.det:
            mean = pred_mean[:, 0].view(-1)
        else:
            mean = pred_mean.view(-1)

        sll = sml = 0
        if self.method.lower() == 'bayes' and self.det:
            smm = pred_var[:, 0, 0].view(-1)
        else:
            smm = 0
        return self.gaussian_likelihood_core(target, mean, log_var, smm, sml,
                                             sll)

    def forward(self, pred, target, step):
        pred_mean = pred[0]
        pred_var = pred[1]

        assert not target.requires_grad
        kl = 0.0
        for module in self.net.children():
            if hasattr(module, 'compute_kl'):
                kl = kl +
module.compute_kl() 84 | if hasattr(self.net, 'compute_kl'): 85 | kl = kl + self.net.compute_kl() 86 | 87 | gaussian_likelihood = self.heteroskedastic_gaussian_loglikelihood if self.use_het \ 88 | else self.homoskedastic_gaussian_loglikelihood 89 | 90 | log_likelihood = gaussian_likelihood(pred_mean, pred_var, target) 91 | batched_likelihood = torch.mean(log_likelihood) 92 | 93 | lmbda = clip((step - self.warmup) / self.anneal, 0, 1) 94 | 95 | loss = lmbda * kl / self.batch_size - batched_likelihood 96 | return loss, batched_likelihood, kl / self.batch_size 97 | 98 | 99 | class ClassificationLoss(nn.Module): 100 | def __init__(self, net, args): 101 | super().__init__() 102 | 103 | self.net = net 104 | self.warmup = args.warmup_updates 105 | self.anneal = args.anneal_updates 106 | self.data_size = args.data_size 107 | self.mcvi = args.mcvi 108 | self.use_samples = args.use_samples 109 | self.n_samples = args.mc_samples 110 | self._step = 0 111 | self.log_mean = vars(args).get('change_criterion', False) 112 | 113 | def set_flag(self, flag_name, value): 114 | setattr(self, flag_name, value) 115 | for m in self.children(): 116 | if hasattr(m, 'set_flag'): 117 | m.set_flag(flag_name, value) 118 | 119 | def forward(self, logits, target): 120 | """ 121 | Compute - kl 122 | 123 | :param logits: shape [batch_size, n_classes] 124 | :param target: shape [batch_size, n_classes] -- one-hot target 125 | :param step: 126 | :return: 127 | total loss 128 | batch_logprob term 129 | total kl term 130 | """ 131 | if not (self.mcvi or self.log_mean): 132 | logsoftmax = logsoftmax_mean(logits) 133 | elif self.mcvi: 134 | logsoftmax = F.log_softmax(logits[0], dim=1) 135 | elif self.log_mean: 136 | posterior = classification_posterior(logits) 137 | posterior = torch.clamp(posterior, 0, 1) 138 | posterior = posterior / torch.sum(posterior, dim=-1, keepdim=True) 139 | logsoftmax = torch.log(posterior + EPS) 140 | 141 | assert not target.requires_grad 142 | kl = 0.0 143 | for module in self.net.children(): 144 | if hasattr(module, 'compute_kl'): 145 | kl = kl + module.compute_kl() 146 | 147 | logprob = torch.sum(target.type(logsoftmax.type()) * logsoftmax, dim=1) 148 | batch_logprob = torch.mean(logprob) 149 | 150 | lmbda = clip((self._step - self.warmup) / self.anneal, 0, 1) 151 | L = lmbda * kl / self.data_size - batch_logprob 152 | return L, batch_logprob, kl / self.data_size, logsoftmax 153 | 154 | def step(self): 155 | self._step += 1 156 | -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | 4 | class LinearDVI(nn.Module): 5 | def __init__(self, args): 6 | super().__init__() 7 | 8 | if args.nonlinearity == 'relu': 9 | layer_factory = ReluGaussian 10 | else: 11 | layer_factory = HeavisideGaussian 12 | 13 | self.fc1 = LinearGaussian(784, 300, certain=True) 14 | self.fc2 = layer_factory(300, 100) 15 | self.fc3 = layer_factory(100, 10) 16 | 17 | if args.mcvi: 18 | self.set_flag('deterministic', False) 19 | 20 | def forward(self, x): 21 | x = self.fc1(x) 22 | x = self.fc2(x) 23 | return self.fc3(x) 24 | 25 | def set_flag(self, flag_name, value): 26 | for m in self.children(): 27 | if hasattr(m, 'set_flag'): 28 | m.set_flag(flag_name, value) 29 | 30 | 31 | class LinearVariance(nn.Module): 32 | def __init__(self, args): 33 | super().__init__() 34 | 35 | if args.var1: 36 | self.fc1 = VarianceGaussian(784, 300, certain=True, 37 | sigma_sq=args.use_sqrt_sigma) 38 
| else: 39 | self.fc1 = DeterministicGaussian(784, 300, certain=True) 40 | 41 | if args.var2: 42 | self.fc2 = VarianceReluGaussian(300, 100, 43 | sigma_sq=args.use_sqrt_sigma) 44 | else: 45 | self.fc2 = DeterministicReluGaussian(300, 100) 46 | 47 | if args.var3: 48 | self.fc3 = VarianceReluGaussian(100, 10, 49 | sigma_sq=args.use_sqrt_sigma) 50 | else: 51 | self.fc3 = DeterministicReluGaussian(100, 10) 52 | 53 | if args.mcvi: 54 | self.set_flag('deterministic', False) 55 | 56 | def forward(self, x): 57 | x = self.fc1(x) 58 | x = self.fc2(x) 59 | return self.fc3(x) 60 | 61 | def set_flag(self, flag_name, value): 62 | for m in self.children(): 63 | if hasattr(m, 'set_flag'): 64 | m.set_flag(flag_name, value) 65 | 66 | 67 | class LinearVDO(nn.Module): 68 | def __init__(self, args): 69 | super().__init__() 70 | self.fc1 = LinearGaussian(784, 300, certain=True) 71 | 72 | if args.nonlinearity == 'relu': 73 | layer_factory = ReluVDO 74 | else: 75 | layer_factory = HeavisideVDO 76 | 77 | self.fc2 = layer_factory(300, 100, deterministic=not args.mcvi) 78 | 79 | if args.n_layers > 1: 80 | self.fc3 = layer_factory(100, 10, deterministic=not args.mcvi) 81 | else: 82 | self.fc3 = DeterministicReluGaussian(100, 10) 83 | 84 | if args.mcvi: 85 | self.set_flag('deterministic', False) 86 | 87 | def forward(self, x): 88 | x = self.fc1(x) 89 | x = self.fc2(x) 90 | return self.fc3(x) 91 | 92 | def set_flag(self, flag_name, value): 93 | for m in self.children(): 94 | if hasattr(m, 'set_flag'): 95 | m.set_flag(flag_name, value) 96 | 97 | def zero_mean(self, mode=True): 98 | for layer in self.children(): 99 | if isinstance(layer, ReluVDO): 100 | layer.set_flag('zero_mean', mode) 101 | 102 | 103 | class LeNetDVI(nn.Module): 104 | def __init__(self, args): 105 | super().__init__() 106 | 107 | self.conv1 = MeanFieldConv2d(1, 6, 5, padding=2, certain=True) 108 | self.conv2 = MeanFieldConv2d(6, 16, 5) 109 | 110 | if args.nonlinearity == 'relu': 111 | layer_factory = ReluGaussian 112 | else: 113 | layer_factory = HeavisideGaussian 114 | 115 | self.fc1 = layer_factory(16 * 5 * 5, 120) 116 | self.fc2 = layer_factory(120, 84) 117 | self.fc3 = layer_factory(84, 10) 118 | 119 | self.avg_pool = AveragePoolGaussian(kernel_size=(2, 2)) 120 | 121 | if args.mcvi: 122 | self.set_flag('deterministic', False) 123 | 124 | def forward(self, x): 125 | x = self.avg_pool(self.conv1(x)) 126 | x = self.avg_pool(self.conv2(x)) 127 | 128 | x_mean = x[0] 129 | x_var = x[1] 130 | 131 | x_mean = x_mean.view(-1, 400) 132 | if x_var is not None: 133 | x_var = x_var.view(-1, 400) 134 | x_var = torch.diag_embed(x_var) 135 | 136 | x = (x_mean, x_var) 137 | x = self.fc1(x) 138 | x = self.fc2(x) 139 | x = self.fc3(x) 140 | return x 141 | 142 | def set_flag(self, flag_name, value): 143 | for m in self.children(): 144 | if hasattr(m, 'set_flag'): 145 | m.set_flag(flag_name, value) 146 | 147 | 148 | class LeNetVDO(nn.Module): 149 | def __init__(self, args): 150 | super().__init__() 151 | 152 | if args.vdo1: 153 | self.conv1 = MeanFieldConv2dVDO(1, 6, 5, padding=2, certain=True, 154 | deterministic=not args.mcvi, 155 | alpha_shape=(1, 1, 1, 1)) 156 | else: 157 | self.conv1 = MeanFieldConv2d(1, 6, 5, padding=2, certain=True, 158 | deterministic=not args.mcvi) 159 | 160 | if args.vdo2: 161 | self.conv2 = MeanFieldConv2dVDO(6, 16, 5, 162 | deterministic=not args.mcvi, 163 | alpha_shape=(1, 1, 1, 1)) 164 | else: 165 | self.conv2 = MeanFieldConv2d(6, 16, 5, deterministic=not args.mcvi) 166 | 167 | if args.vdo3: 168 | self.fc1 = ReluVDO(16 * 5 * 5, 120, 
deterministic=not args.mcvi) 169 | else: 170 | self.fc1 = DeterministicReluGaussian(16 * 5 * 5, 120, 171 | deterministic=not args.mcvi) 172 | 173 | if args.vdo4: 174 | self.fc2 = ReluVDO(120, 84, deterministic=not args.mcvi) 175 | else: 176 | self.fc2 = DeterministicReluGaussian(120, 84, 177 | deterministic=not args.mcvi) 178 | 179 | if args.vdo5: 180 | self.fc3 = ReluVDO(84, 10, deterministic=not args.mcvi) 181 | else: 182 | self.fc3 = DeterministicReluGaussian(84, 10, 183 | deterministic=not args.mcvi) 184 | 185 | self.avg_pool = AveragePoolGaussian(kernel_size=(2, 2)) 186 | 187 | if args.mcvi: 188 | self.set_flag('deterministic', False) 189 | 190 | def zero_mean(self, mode=True): 191 | for layer in self.children(): 192 | if isinstance(layer, ReluVDO) or isinstance(layer, 193 | MeanFieldConv2dVDO): 194 | if layer.log_alpha > 3 and mode: 195 | layer.set_flag('zero_mean', mode) 196 | if not mode: 197 | layer.set_flag('zero_mean', False) 198 | 199 | def forward(self, x): 200 | x = self.avg_pool(self.conv1(x)) 201 | x = self.avg_pool(self.conv2(x)) 202 | 203 | x_mean = x[0] 204 | x_var = x[1] 205 | 206 | x_mean = x_mean.view(-1, 400) 207 | if x_var is not None: 208 | x_var = x_var.view(-1, 400) 209 | x_var = torch.diag_embed(x_var) 210 | 211 | x = (x_mean, x_var) 212 | x = self.fc1(x) 213 | x = self.fc2(x) 214 | x = self.fc3(x) 215 | return x 216 | 217 | def set_flag(self, flag_name, value): 218 | for m in self.children(): 219 | if hasattr(m, 'set_flag'): 220 | m.set_flag(flag_name, value) 221 | 222 | 223 | class LeNetVariance(nn.Module): 224 | def __init__(self, args): 225 | super().__init__() 226 | 227 | if args.var1: 228 | self.conv1 = VarianceMeanFieldConv2d(1, 6, 5, padding=2, 229 | certain=True, 230 | deterministic=not args.mcvi, 231 | sigma_sq=args.use_sqrt_sigma) 232 | else: 233 | self.conv1 = MeanFieldConv2d(1, 6, 5, padding=2, certain=True, 234 | deterministic=not args.mcvi) 235 | 236 | if args.var2: 237 | self.conv2 = VarianceMeanFieldConv2d(6, 16, 5, 238 | deterministic=not args.mcvi, 239 | sigma_sq=args.use_sqrt_sigma) 240 | else: 241 | self.conv2 = MeanFieldConv2d(6, 16, 5, deterministic=not args.mcvi) 242 | 243 | if args.var3: 244 | self.fc1 = VarianceReluGaussian(16 * 5 * 5, 120, 245 | deterministic=not args.mcvi, 246 | sigma_sq=args.use_sqrt_sigma) 247 | else: 248 | self.fc1 = DeterministicReluGaussian(16 * 5 * 5, 120, 249 | deterministic=not args.mcvi) 250 | 251 | if args.var4: 252 | self.fc2 = VarianceReluGaussian(120, 84, 253 | deterministic=not args.mcvi, 254 | sigma_sq=args.use_sqrt_sigma) 255 | else: 256 | self.fc2 = DeterministicReluGaussian(120, 84, 257 | deterministic=not args.mcvi) 258 | 259 | if args.var5: 260 | self.fc3 = VarianceReluGaussian(84, 10, deterministic=not args.mcvi, 261 | sigma_sq=args.use_sqrt_sigma) 262 | else: 263 | self.fc3 = DeterministicReluGaussian(84, 10, 264 | deterministic=not args.mcvi) 265 | 266 | self.avg_pool = AveragePoolGaussian(kernel_size=(2, 2)) 267 | 268 | if args.mcvi: 269 | self.set_flag('deterministic', False) 270 | 271 | def forward(self, x): 272 | x = self.avg_pool(self.conv1(x)) 273 | x = self.avg_pool(self.conv2(x)) 274 | 275 | x_mean = x[0] 276 | x_var = x[1] 277 | 278 | x_mean = x_mean.view(-1, 400) 279 | if x_var is not None: 280 | x_var = x_var.view(-1, 400) 281 | x_var = torch.diag_embed(x_var) 282 | 283 | x = (x_mean, x_var) 284 | x = self.fc1(x) 285 | x = self.fc2(x) 286 | x = self.fc3(x) 287 | return x 288 | 289 | def set_flag(self, flag_name, value): 290 | for m in self.children(): 291 | if hasattr(m, 'set_flag'): 
            m.set_flag(flag_name, value)
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.datasets import make_classification, make_circles
from torchvision import datasets, transforms

from .bayesian_utils import sample_softmax, classification_posterior


def one_hot_encoding(tensor, n_classes, device):
    ohe = torch.LongTensor(tensor.size(0), n_classes).to(device)
    ohe.zero_()
    ohe.scatter_(1, tensor, 1)
    return ohe


def pred2acc(pred, y, batch_size):
    t = torch.sum(torch.squeeze(pred) == torch.squeeze(y), dtype=torch.float32)
    return (t / batch_size).item()


def draw_regression_result(data, data_generator, predictions=None, name=None):
    x = predictions[0]
    train_x = np.arange(np.min(x.reshape(-1)),
                        np.max(x.reshape(-1)), 1 / 100)

    # plot the training data distribution
    plt.figure(figsize=(14, 10))

    plt.plot(train_x, data_generator['mean'](train_x), 'red', label='data mean')
    plt.fill_between(train_x,
                     data_generator['mean'](train_x) - data_generator['std'](
                         train_x),
                     data_generator['mean'](train_x) + data_generator['std'](
                         train_x),
                     color='orange', alpha=1, label='data 1-std')
    plt.plot(data[0][0], data[0][1], 'r.', alpha=0.2, label='train samples')

    # plot the model distribution
    if predictions is not None:
        x = predictions[0]

        y_mean = predictions[1]['mean'][:, 0]
        ell_mean = 2 * math.log(0.2)
        y_var = predictions[1]['cov'][:, 0, 0]
        ell_var = 0

        heteroskedastic_part = np.exp(0.5 * ell_mean)
        full_std = np.sqrt(y_var + np.exp(ell_mean + 0.5 * ell_var))

        plt.scatter(x, y_mean, label='model mean')
        plt.scatter(x, y_mean - heteroskedastic_part, color='g', alpha=0.2)
        plt.scatter(x, y_mean + heteroskedastic_part, color='g', alpha=0.2,
                    label=r'$\ell$ contrib')

        plt.scatter(x, y_mean - full_std, color='b', alpha=0.2,
                    label='model 1-std')
        plt.scatter(x, y_mean + full_std, color='b', alpha=0.2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.ylim([-3, 2])
    plt.legend()
    plt.savefig(name)
    plt.close()


def get_predictions(data, model, args, mcvi=False):
    output = model(data)

    output_cov = output[1]
    output_mean = output[0]

    n = output_mean.size(0)

    if not mcvi:
        m = output_mean.size(1)

        out_cov = torch.reshape(output_cov, (n, m, m))
        out_mean = output_mean

        x = data.cpu().detach().numpy().squeeze()
        y = {'mean': out_mean.cpu().detach().numpy(),
             'cov': out_cov.cpu().detach().numpy()}
    else:
        out_mean = output_mean.unsqueeze(-1)
        out_cov = torch.zeros_like(out_mean).unsqueeze(-1)

        x = data.cpu().detach().numpy().squeeze()
        y = {'mean': out_mean.cpu().detach().numpy(),
             'cov': out_cov.cpu().detach().numpy()}
    return x, y


def generate_regression_data(args, sample_data, base_model):
    data_size = {'train': args.data_size, 'test': args.test_size}
    toy_data = []
    for section in ['train', 'test']:
        # use the size of the current split (previously always 'train')
        x = (np.random.rand(data_size[section], 1) - 0.5)
        toy_data.append([x, sample_data(x, args).reshape(-1)])
    x =
np.arange(-1, 1, 1 / 100) 105 | toy_data.append([[[_] for _ in x], base_model(x)]) 106 | x_train = toy_data[0][0] 107 | y_train = toy_data[0][1] 108 | 109 | x_test = toy_data[1][0] 110 | y_test = toy_data[1][1] 111 | 112 | x_train = torch.FloatTensor(x_train).to(device=args.device) 113 | y_train = torch.FloatTensor(y_train).to(device=args.device) 114 | 115 | x_test = torch.FloatTensor(x_test).to(device=args.device) 116 | y_test = torch.FloatTensor(y_test).to(device=args.device) 117 | 118 | return x_train, y_train, x_test, y_test, toy_data 119 | 120 | 121 | def generate_classification_data(args): 122 | if args.dataset.strip().lower() == 'classification': 123 | n_informative = int(args.input_size * 0.8) 124 | x_train, y_train = make_classification(n_samples=args.data_size, 125 | n_features=args.input_size, 126 | n_informative=n_informative, 127 | n_redundant=args.input_size - n_informative, 128 | n_classes=args.n_classes) 129 | 130 | x_test, y_test = make_classification(n_samples=args.test_size, 131 | n_features=args.input_size, 132 | n_informative=n_informative, 133 | n_redundant=args.input_size - n_informative, 134 | n_classes=args.n_classes) 135 | n_classes = args.n_classes 136 | elif args.dataset.strip().lower() == 'circles': 137 | x_train, y_train = make_circles(n_samples=args.data_size) 138 | x_test, y_test = make_circles(n_samples=args.test_size) 139 | n_classes = 2 140 | 141 | x_train = torch.FloatTensor(x_train).to(args.device) 142 | y_train = torch.LongTensor(y_train).to(args.device).view(-1, 1) 143 | 144 | y_onehot_train = one_hot_encoding(y_train, n_classes, args.device) 145 | 146 | x_test = torch.FloatTensor(x_test).to(args.device) 147 | y_test = torch.LongTensor(y_test).to(args.device).view(-1, 1) 148 | 149 | y_onehot_test = one_hot_encoding(y_test, n_classes, args.device) 150 | 151 | return x_train, y_train, y_onehot_train, x_test, y_test, y_onehot_test 152 | 153 | 154 | def draw_classification_results(data, prediction, name, args): 155 | """ 156 | 157 | :param data: input to draw, should be 2D 158 | :param prediction: predicted class labels 159 | :param name: output file name 160 | :return: 161 | """ 162 | 163 | x = data.detach().cpu().numpy() 164 | y = prediction.detach().cpu().numpy().squeeze() 165 | 166 | plt.figure(figsize=(10, 8)) 167 | plt.scatter(x[:, 0], x[:, 1], c=y) 168 | 169 | if args.dataset == "circles": 170 | path = 'pics/classification/circles' 171 | else: 172 | path = 'pics/classification/cls' 173 | 174 | if not os.path.exists(path): 175 | os.mkdir(path) 176 | 177 | filename = path + os.sep + name 178 | plt.savefig(filename) 179 | plt.close() 180 | 181 | 182 | def load_mnist(args): 183 | train_loader = torch.utils.data.DataLoader( 184 | datasets.MNIST('../../data', train=True, download=True, 185 | transform=transforms.Compose([ 186 | transforms.ToTensor(), 187 | transforms.Normalize((0.1307,), (0.3081,)) 188 | ])), 189 | batch_size=args.batch_size, shuffle=True, num_workers=0) 190 | 191 | test_loader = torch.utils.data.DataLoader( 192 | datasets.MNIST('../../data', train=False, transform=transforms.Compose([ 193 | transforms.ToTensor(), 194 | transforms.Normalize((0.1307,), (0.3081,)) 195 | ])), 196 | batch_size=args.test_batch_size, shuffle=True, num_workers=0) 197 | return train_loader, test_loader 198 | 199 | 200 | def save_checkpoint(state, dir, filename): 201 | torch.save(state, os.path.join(dir, filename)) 202 | 203 | 204 | def load_checkpoint(experiment, name='last.pth.tar'): 205 | filename = os.path.join('checkpoints', experiment, name) 206 | 
checkpoint = torch.load(filename) 207 | return checkpoint['state_dict'] 208 | 209 | 210 | def mc_prediction(model, input, n_samples): 211 | logits = torch.stack([model(input)[0] for _ in range(n_samples)], dim=0) 212 | probs = F.softmax(logits, dim=-1) 213 | mean_probs = torch.mean(probs, dim=0) 214 | return mean_probs 215 | 216 | 217 | def evaluate(model, loader, mode, args, zero_mean=False): 218 | prev_mcvi, prev_samples = args.mcvi, args.use_samples 219 | accuracy = [] 220 | 221 | if mode == 'mcvi': 222 | model.set_flag('deterministic', False) 223 | args.mcvi = True 224 | if mode == 'dvi': 225 | model.set_flag('deterministic', True) 226 | args.mcvi = False 227 | if mode == 'samples_dvi': 228 | model.set_flag('deterministic', True) 229 | args.use_samples = True 230 | 231 | if zero_mean: 232 | model.zero_mean() 233 | 234 | with torch.no_grad(): 235 | for data, y_test in loader: 236 | x = data.to(args.device) 237 | if args.reshape: 238 | x = x.view(-1, 784) 239 | 240 | y = y_test.to(args.device) 241 | 242 | if mode == 'mcvi': 243 | probs = mc_prediction(model, x, args.mc_samples) 244 | elif mode == 'samples_dvi': 245 | activations = model(x) 246 | probs = sample_softmax(activations, n_samples=args.mc_samples) 247 | elif mode == 'dvi': 248 | activations = model(x) 249 | probs = classification_posterior(activations) 250 | else: 251 | raise ValueError('invalid mode for evaluate') 252 | 253 | pred = torch.argmax(probs, dim=1) 254 | accuracy.append(pred2acc(pred, y, args.test_batch_size)) 255 | 256 | args.mcvi = prev_mcvi 257 | args.use_samples = prev_samples 258 | if prev_mcvi: 259 | model.set_flag('mcvi', True) 260 | else: 261 | model.set_flag('deterministic', True) 262 | 263 | if zero_mean: 264 | model.zero_mean(False) 265 | 266 | return np.mean(accuracy) 267 | -------------------------------------------------------------------------------- /pics/classification/circles/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/classification/circles/test.png -------------------------------------------------------------------------------- /pics/classification/circles/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/classification/circles/train.png -------------------------------------------------------------------------------- /pics/regression/det/last.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/regression/det/last.png -------------------------------------------------------------------------------- /pics/regression/det/swapped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/regression/det/swapped.png -------------------------------------------------------------------------------- /pics/regression/mcvi/last.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/regression/mcvi/last.png -------------------------------------------------------------------------------- /pics/regression/mcvi/swapped.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/pics/regression/mcvi/swapped.png
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovalexander/DVI/76d1c2261e48d5d804af50b9037c6cd650eb95c2/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/fc-variance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from time import time
3 | 
4 | import numpy as np
5 | 
6 | from core.logger import Logger
7 | from core.losses import ClassificationLoss
8 | from core.models import *
9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc
10 | 
11 | np.random.seed(42)
12 | 
13 | EPS = 1e-6
14 | 
15 | parser = argparse.ArgumentParser()
16 | 
17 | parser.add_argument('--device', type=int, default=0)
18 | parser.add_argument('--anneal_updates', type=int, default=1)
19 | parser.add_argument('--warmup_updates', type=int, default=0)
20 | parser.add_argument('--mcvi', action='store_true')
21 | parser.add_argument('--mc_samples', default=20, type=int)
22 | parser.add_argument('--clip_grad', type=float, default=0.5)
23 | parser.add_argument('--lr', type=float, default=1e-3)
24 | parser.add_argument('--epochs', type=int, default=150)
25 | parser.add_argument('--zm', action='store_true')
26 | parser.add_argument('--no_mc', action='store_true')
27 | parser.add_argument('--use_samples', action='store_true')
28 | parser.add_argument('--reshape', action='store_true', default=True)
29 | 
30 | parser.add_argument('--var1', action='store_true')
31 | parser.add_argument('--var2', action='store_true')
32 | parser.add_argument('--var3', action='store_true')
33 | 
34 | fmt = {'kl': '3.3e',
35 |        'tr_elbo': '3.3e',
36 |        'tr_acc': '.4f',
37 |        'tr_ll': '.3f',
38 |        'te_acc_dvi': '.4f',
39 |        'te_acc_mcvi': '.4f',
40 |        'te_acc_samples': '.4f',
41 |        'te_acc_dvi_zm': '.4f',
42 |        'te_acc_mcvi_zm': '.4f',
43 |        'tr_time': '.3f',
44 |        'te_time_dvi': '.3f',
45 |        'te_time_mcvi': '.3f'}
46 | 
47 | if __name__ == "__main__":
48 |     args = parser.parse_args()
49 |     args.device = torch.device(
50 |         'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')
51 | 
52 |     if args.var3:
53 |         args.change_criterion = True
54 |     model = LinearVariance(args).to(args.device)
55 | 
56 |     args.batch_size, args.test_batch_size = 32, 32
57 |     train_loader, test_loader = load_mnist(args)
58 |     args.data_size = len(train_loader.dataset)
59 | 
60 |     logger = Logger('fc-variance', fmt=fmt)  # name the experiment after this fully-connected script
61 |     logger.print(args)
62 |     logger.print(model)
63 | 
64 |     criterion = ClassificationLoss(model, args)
65 |     optimizer = torch.optim.Adam([p for p in model.parameters()
66 |                                   if p.requires_grad], lr=args.lr)
67 | 
68 |     for epoch in range(args.epochs):
69 |         t0 = time()
70 | 
71 |         model.train()
72 |         criterion.step()
73 | 
74 |         elbo, cat_mean, kls, accuracy = [], [], [], []
75 |         for data, y_train in train_loader:
76 |             optimizer.zero_grad()
77 | 
78 |             x_train = data.view(-1, 784).to(args.device)
79 |             y_train = y_train.to(args.device)
80 | 
81 |             y_ohe = one_hot_encoding(y_train[:, None], 10, args.device)
82 |             y_logits = model(x_train)
83 | 
84 |             loss, categorical_mean, kl, logsoftmax = criterion(y_logits,
85 |                                                                y_ohe)
86 | 
87 |             pred = torch.argmax(logsoftmax, dim=1)
88 | 
loss.backward() 89 | 90 | if args.clip_grad > 0: 91 | nn.utils.clip_grad.clip_grad_value_(model.parameters(), 92 | args.clip_grad) 93 | 94 | optimizer.step() 95 | 96 | elbo.append(-loss.item()) 97 | cat_mean.append(categorical_mean.item()) 98 | kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl) 99 | accuracy.append(pred2acc(pred, y_train, args.batch_size)) 100 | 101 | t1 = time() - t0 102 | elbo = np.mean(elbo) 103 | cat_mean = np.mean(cat_mean) 104 | kl = np.mean(kls) 105 | accuracy = np.mean(accuracy) 106 | logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean, 107 | tr_time=t1) 108 | 109 | model.eval() 110 | t_dvi = time() 111 | test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args) 112 | t_dvi = time() - t_dvi 113 | 114 | if not args.no_mc: 115 | t_mc = time() 116 | test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args) 117 | t_mc = time() - t_mc 118 | logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc) 119 | 120 | test_acc_samples = evaluate(model, test_loader, mode='samples_dvi', 121 | args=args) 122 | logger.add(epoch, te_acc_dvi=test_acc_dvi, 123 | te_acc_samples=test_acc_samples, te_time_dvi=t_dvi) 124 | 125 | logger.iter_info() 126 | logger.save(silent=True) 127 | torch.save(model.state_dict(), logger.checkpoint) 128 | logger.save() 129 | -------------------------------------------------------------------------------- /scripts/lenet-det.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | 6 | from core.logger import Logger 7 | from core.losses import ClassificationLoss 8 | from core.models import * 9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc 10 | 11 | np.random.seed(42) 12 | 13 | EPS = 1e-6 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--device', type=int, default=0) 18 | parser.add_argument('--anneal_updates', type=int, default=1) 19 | parser.add_argument('--warmup_updates', type=int, default=0) 20 | parser.add_argument('--mcvi', action='store_true') 21 | parser.add_argument('--mc_samples', default=20, type=int) 22 | parser.add_argument('--clip_grad', type=float, default=0.5) 23 | parser.add_argument('--lr', type=float, default=1e-3) 24 | parser.add_argument('--epochs', type=int, default=150) 25 | parser.add_argument('--zm', action='store_true') 26 | parser.add_argument('--no_mc', action='store_true') 27 | parser.add_argument('--use_samples', action='store_true') 28 | parser.add_argument('--reshape', action='store_true', default=False) 29 | 30 | 31 | fmt = {'kl': '3.3e', 32 | 'tr_elbo': '3.3e', 33 | 'tr_acc': '.4f', 34 | 'tr_ll': '.3f', 35 | 'te_acc_dvi': '.4f', 36 | 'te_acc_mcvi': '.4f', 37 | 'te_acc_samples': '.4f', 38 | 'te_acc_dvi_zm': '.4f', 39 | 'te_acc_mcvi_zm': '.4f', 40 | 'tr_time': '.3f', 41 | 'te_time_dvi': '.3f', 42 | 'te_time_mcvi': '.3f'} 43 | 44 | if __name__ == "__main__": 45 | args = parser.parse_args() 46 | args.device = torch.device( 47 | 'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu') 48 | 49 | model = LeNetDVI(args).to(args.device) 50 | 51 | args.batch_size, args.test_batch_size = 32, 32 52 | train_loader, test_loader = load_mnist(args) 53 | args.data_size = len(train_loader.dataset) 54 | 55 | logger = Logger('lenet-deterministic', fmt=fmt) 56 | logger.print(args) 57 | logger.print(model) 58 | 59 | criterion = ClassificationLoss(model, args) 60 | optimizer = torch.optim.Adam([p for p in 
model.parameters() 61 | if p.requires_grad], lr=args.lr) 62 | 63 | for epoch in range(args.epochs): 64 | t0 = time() 65 | 66 | model.train() 67 | criterion.step() 68 | 69 | elbo, cat_mean, kls, accuracy = [], [], [], [] 70 | for data, y_train in train_loader: 71 | optimizer.zero_grad() 72 | 73 | x_train = data.to(args.device) 74 | y_train = y_train.to(args.device) 75 | 76 | y_ohe = one_hot_encoding(y_train[:, None], 10, args.device) 77 | y_logits = model(x_train) 78 | 79 | loss, categorical_mean, kl, logsoftmax = criterion(y_logits, 80 | y_ohe) 81 | 82 | pred = torch.argmax(logsoftmax, dim=1) 83 | loss.backward() 84 | 85 | if args.clip_grad > 0: 86 | nn.utils.clip_grad.clip_grad_value_(model.parameters(), 87 | args.clip_grad) 88 | 89 | optimizer.step() 90 | 91 | elbo.append(-loss.item()) 92 | cat_mean.append(categorical_mean.item()) 93 | kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl) 94 | accuracy.append(pred2acc(pred, y_train, args.batch_size)) 95 | 96 | t1 = time() - t0 97 | elbo = np.mean(elbo) 98 | cat_mean = np.mean(cat_mean) 99 | kl = np.mean(kls) 100 | accuracy = np.mean(accuracy) 101 | logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean, 102 | tr_time=t1) 103 | 104 | model.eval() 105 | t_dvi = time() 106 | test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args) 107 | t_dvi = time() - t_dvi 108 | 109 | if not args.no_mc: 110 | t_mc = time() 111 | test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args) 112 | t_mc = time() - t_mc 113 | logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc) 114 | 115 | test_acc_samples = evaluate(model, test_loader, mode='samples_dvi', 116 | args=args) 117 | logger.add(epoch, te_acc_dvi=test_acc_dvi, 118 | te_acc_samples=test_acc_samples, te_time_dvi=t_dvi) 119 | 120 | logger.iter_info() 121 | logger.save(silent=True) 122 | torch.save(model.state_dict(), logger.checkpoint) 123 | logger.save() 124 | -------------------------------------------------------------------------------- /scripts/lenet-variance-first.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | 6 | from core.logger import Logger 7 | from core.losses import ClassificationLoss 8 | from core.models import * 9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc 10 | 11 | np.random.seed(42) 12 | 13 | EPS = 1e-6 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--device', type=int, default=0) 18 | parser.add_argument('--anneal_updates', type=int, default=1) 19 | parser.add_argument('--warmup_updates', type=int, default=0) 20 | parser.add_argument('--mcvi', action='store_true') 21 | parser.add_argument('--mc_samples', default=20, type=int) 22 | parser.add_argument('--clip_grad', type=float, default=0.5) 23 | parser.add_argument('--lr', type=float, default=1e-3) 24 | parser.add_argument('--epochs', type=int, default=150) 25 | parser.add_argument('--zm', action='store_true') 26 | parser.add_argument('--no_mc', action='store_true') 27 | parser.add_argument('--use_samples', action='store_true') 28 | parser.add_argument('--reshape', action='store_true', default=False) 29 | parser.add_argument('--use-sqrt-sigma', action='store_true', default=False) 30 | 31 | parser.add_argument('--var2', action='store_true') 32 | parser.add_argument('--var3', action='store_true') 33 | parser.add_argument('--var4', action='store_true') 34 | parser.add_argument('--var5', action='store_true', 
default=False)
35 | 
36 | fmt = {'kl': '3.3e',
37 |        'tr_elbo': '3.3e',
38 |        'tr_acc': '.4f',
39 |        'tr_ll': '.3f',
40 |        'te_acc_dvi': '.4f',
41 |        'te_acc_mcvi': '.4f',
42 |        'te_acc_samples': '.4f',
43 |        'te_acc_dvi_zm': '.4f',
44 |        'te_acc_mcvi_zm': '.4f',
45 |        'tr_time': '.3f',
46 |        'te_time_dvi': '.3f',
47 |        'te_time_mcvi': '.3f'}
48 | 
49 | 
50 | def forward_hook(self, input, output):
51 |     if isinstance(output, tuple):
52 |         if torch.any(torch.isnan(output[0])):
53 |             print(self.__class__.__name__ + ' forward output_mean nan')
54 |         if output[1] is not None and torch.any(torch.isnan(output[1])):
55 |             print(self.__class__.__name__ + ' forward output_var nan')
56 | 
57 | 
58 | def backward_hook(self, grad_input, grad_output):
59 |     if torch.any(torch.isnan(grad_output[0])):
60 |         print(self.__class__.__name__ + ' backward grad_output nan')
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     args = parser.parse_args()
65 |     args.device = torch.device(
66 |         'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')
67 | 
68 |     args.var1 = True
69 |     model = LeNetVariance(args).to(args.device)
70 | 
71 |     # for layer in model.children():
72 |     #     layer.register_forward_hook(forward_hook)
73 |     #     layer.register_backward_hook(backward_hook)
74 | 
75 |     args.batch_size, args.test_batch_size = 32, 32
76 |     train_loader, test_loader = load_mnist(args)
77 |     args.data_size = len(train_loader.dataset)
78 | 
79 |     logger = Logger('lenet-variance', fmt=fmt)
80 |     logger.print(args)
81 |     logger.print(model)
82 | 
83 |     criterion = ClassificationLoss(model, args)
84 |     optimizer = torch.optim.Adam([p for p in model.parameters()
85 |                                   if p.requires_grad], lr=args.lr)
86 | 
87 |     for epoch in range(args.epochs):
88 |         t0 = time()
89 | 
90 |         model.train()
91 |         criterion.step()
92 | 
93 |         elbo, cat_mean, kls, accuracy = [], [], [], []
94 |         for data, y_train in train_loader:
95 |             optimizer.zero_grad()
96 | 
97 |             x_train = data.to(args.device)
98 |             y_train = y_train.to(args.device)
99 | 
100 |             y_ohe = one_hot_encoding(y_train[:, None], 10, args.device)
101 |             y_logits = model(x_train)
102 | 
103 |             loss, categorical_mean, kl, logsoftmax = criterion(y_logits,
104 |                                                                y_ohe)
105 | 
106 |             pred = torch.argmax(logsoftmax, dim=1)
107 |             loss.backward()
108 | 
109 |             if args.clip_grad > 0:
110 |                 nn.utils.clip_grad.clip_grad_value_(model.parameters(),
111 |                                                     args.clip_grad)
112 | 
113 |             optimizer.step()
114 | 
115 |             elbo.append(-loss.item())
116 |             cat_mean.append(categorical_mean.item())
117 |             kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl)
118 |             accuracy.append(pred2acc(pred, y_train, args.batch_size))
119 | 
120 |         t1 = time() - t0
121 |         elbo = np.mean(elbo)
122 |         cat_mean = np.mean(cat_mean)
123 |         kl = np.mean(kls)
124 |         accuracy = np.mean(accuracy)
125 |         logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean,
126 |                    tr_time=t1)
127 | 
128 |         model.eval()
129 |         t_dvi = time()
130 |         test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args)
131 |         t_dvi = time() - t_dvi
132 | 
133 |         if not args.no_mc:
134 |             t_mc = time()
135 |             test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args)
136 |             t_mc = time() - t_mc
137 |             logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc)
138 | 
139 |         logger.add(epoch, te_acc_dvi=test_acc_dvi, te_time_dvi=t_dvi)
140 | 
141 |         logger.iter_info()
142 |         logger.save(silent=True)
143 |         torch.save(model.state_dict(), logger.checkpoint)
144 |     logger.save()
145 | 
--------------------------------------------------------------------------------
/scripts/lenet-variance-last.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from time import time
3 | 
4 | import numpy as np
5 | 
6 | from core.logger import Logger
7 | from core.losses import ClassificationLoss
8 | from core.models import *
9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc
10 | 
11 | np.random.seed(42)
12 | 
13 | EPS = 1e-6
14 | 
15 | parser = argparse.ArgumentParser()
16 | 
17 | parser.add_argument('--device', type=int, default=0)
18 | parser.add_argument('--anneal_updates', type=int, default=1)
19 | parser.add_argument('--warmup_updates', type=int, default=0)
20 | parser.add_argument('--mcvi', action='store_true')
21 | parser.add_argument('--mc_samples', default=20, type=int)
22 | parser.add_argument('--clip_grad', type=float, default=0.5)
23 | parser.add_argument('--lr', type=float, default=1e-3)
24 | parser.add_argument('--epochs', type=int, default=150)
25 | parser.add_argument('--zm', action='store_true')
26 | parser.add_argument('--no_mc', action='store_true')
27 | parser.add_argument('--use_samples', action='store_true')
28 | parser.add_argument('--reshape', action='store_true', default=False)
29 | parser.add_argument('--use-sqrt-sigma', action='store_true', default=False)
30 | 
31 | parser.add_argument('--var1', action='store_true')
32 | parser.add_argument('--var2', action='store_true')
33 | parser.add_argument('--var3', action='store_true')
34 | parser.add_argument('--var4', action='store_true')
35 | 
36 | fmt = {'kl': '3.3e',
37 |        'tr_elbo': '3.3e',
38 |        'tr_acc': '.4f',
39 |        'tr_ll': '.3f',
40 |        'te_acc_dvi': '.4f',
41 |        'te_acc_mcvi': '.4f',
42 |        'te_acc_samples': '.4f',
43 |        'te_acc_dvi_zm': '.4f',
44 |        'te_acc_mcvi_zm': '.4f',
45 |        'tr_time': '.3f',
46 |        'te_time_dvi': '.3f',
47 |        'te_time_mcvi': '.3f'}
48 | 
49 | 
50 | def forward_hook(self, input, output):
51 |     if isinstance(output, tuple):
52 |         if torch.any(torch.isnan(output[0])):
53 |             print(self.__class__.__name__ + ' forward output_mean nan')
54 |         if output[1] is not None and torch.any(torch.isnan(output[1])):
55 |             print(self.__class__.__name__ + ' forward output_var nan')
56 | 
57 | 
58 | def backward_hook(self, grad_input, grad_output):
59 |     if torch.any(torch.isnan(grad_output[0])):
60 |         print(self.__class__.__name__ + ' backward grad_output nan')
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     args = parser.parse_args()
65 |     args.device = torch.device(
66 |         'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')
67 | 
68 |     args.var5 = True
69 |     args.change_criterion = True
70 |     model = LeNetVariance(args).to(args.device)
71 | 
72 |     args.batch_size, args.test_batch_size = 32, 32
73 |     train_loader, test_loader = load_mnist(args)
74 |     args.data_size = len(train_loader.dataset)
75 | 
76 |     # for layer in model.children():
77 |     #     layer.register_forward_hook(forward_hook)
78 |     #     layer.register_backward_hook(backward_hook)
79 | 
80 |     logger = Logger('lenet-variance', fmt=fmt)
81 |     logger.print(args)
82 |     logger.print(model)
83 | 
84 |     criterion = ClassificationLoss(model, args)
85 |     optimizer = torch.optim.Adam([p for p in model.parameters()
86 |                                   if p.requires_grad], lr=args.lr)
87 | 
88 |     for epoch in range(args.epochs):
89 |         t0 = time()
90 | 
91 |         model.train()
92 |         criterion.step()
93 | 
94 |         elbo, cat_mean, kls, accuracy = [], [], [], []
95 |         for data, y_train in train_loader:
96 |             optimizer.zero_grad()
97 | 
98 |             x_train = data.to(args.device)
99 |             y_train = y_train.to(args.device)
100 | 
101 |             y_ohe = one_hot_encoding(y_train[:, None], 10, args.device)
102 | 
y_logits = model(x_train) 103 | 104 | loss, categorical_mean, kl, logsoftmax = criterion(y_logits, 105 | y_ohe) 106 | 107 | pred = torch.argmax(logsoftmax, dim=1) 108 | loss.backward() 109 | 110 | if args.clip_grad > 0: 111 | nn.utils.clip_grad.clip_grad_value_(model.parameters(), 112 | args.clip_grad) 113 | 114 | optimizer.step() 115 | 116 | elbo.append(-loss.item()) 117 | cat_mean.append(categorical_mean.item()) 118 | kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl) 119 | accuracy.append(pred2acc(pred, y_train, args.batch_size)) 120 | 121 | t1 = time() - t0 122 | elbo = np.mean(elbo) 123 | cat_mean = np.mean(cat_mean) 124 | kl = np.mean(kls) 125 | accuracy = np.mean(accuracy) 126 | logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean, 127 | tr_time=t1) 128 | 129 | model.eval() 130 | t_dvi = time() 131 | test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args) 132 | t_dvi = time() - t_dvi 133 | 134 | if not args.no_mc: 135 | t_mc = time() 136 | test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args) 137 | t_mc = time() - t_mc 138 | logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc) 139 | 140 | test_acc_samples = evaluate(model, test_loader, mode='samples_dvi', 141 | args=args) 142 | logger.add(epoch, te_acc_dvi=test_acc_dvi, 143 | te_acc_samples=test_acc_samples, te_time_dvi=t_dvi) 144 | 145 | logger.iter_info() 146 | logger.save(silent=True) 147 | torch.save(model.state_dict(), logger.checkpoint) 148 | logger.save() 149 | -------------------------------------------------------------------------------- /scripts/lenet-variance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | 6 | from core.logger import Logger 7 | from core.losses import ClassificationLoss 8 | from core.models import * 9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc 10 | 11 | np.random.seed(42) 12 | 13 | EPS = 1e-6 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--device', type=int, default=0) 18 | parser.add_argument('--anneal_updates', type=int, default=1) 19 | parser.add_argument('--warmup_updates', type=int, default=0) 20 | parser.add_argument('--mcvi', action='store_true') 21 | parser.add_argument('--mc_samples', default=20, type=int) 22 | parser.add_argument('--clip_grad', type=float, default=0.5) 23 | parser.add_argument('--lr', type=float, default=1e-3) 24 | parser.add_argument('--epochs', type=int, default=150) 25 | parser.add_argument('--zm', action='store_true') 26 | parser.add_argument('--no_mc', action='store_true') 27 | parser.add_argument('--use_samples', action='store_true') 28 | parser.add_argument('--reshape', action='store_true', default=False) 29 | parser.add_argument('--use-sqrt-sigma', action='store_true', default=False) 30 | 31 | parser.add_argument('--var1', action='store_true') 32 | parser.add_argument('--var2', action='store_true') 33 | parser.add_argument('--var3', action='store_true') 34 | parser.add_argument('--var4', action='store_true') 35 | 36 | fmt = {'kl': '3.3e', 37 | 'tr_elbo': '3.3e', 38 | 'tr_acc': '.4f', 39 | 'tr_ll': '.3f', 40 | 'te_acc_dvi': '.4f', 41 | 'te_acc_mcvi': '.4f', 42 | 'te_acc_samples': '.4f', 43 | 'te_acc_dvi_zm': '.4f', 44 | 'te_acc_mcvi_zm': '.4f', 45 | 'tr_time': '.3f', 46 | 'te_time_dvi': '.3f', 47 | 'te_time_mcvi': '.3f'} 48 | 49 | if __name__ == "__main__": 50 | args = parser.parse_args() 51 | args.device = torch.device( 52 | 
'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu') 53 | 54 | args.var5 = False 55 | model = LeNetVariance(args).to(args.device) 56 | 57 | args.batch_size, args.test_batch_size = 32, 32 58 | train_loader, test_loader = load_mnist(args) 59 | args.data_size = len(train_loader.dataset) 60 | 61 | logger = Logger('lenet-variance', fmt=fmt) 62 | logger.print(args) 63 | logger.print(model) 64 | 65 | criterion = ClassificationLoss(model, args) 66 | optimizer = torch.optim.Adam([p for p in model.parameters() 67 | if p.requires_grad], lr=args.lr) 68 | 69 | for epoch in range(args.epochs): 70 | t0 = time() 71 | 72 | model.train() 73 | criterion.step() 74 | 75 | elbo, cat_mean, kls, accuracy = [], [], [], [] 76 | for data, y_train in train_loader: 77 | optimizer.zero_grad() 78 | 79 | x_train = data.to(args.device) 80 | y_train = y_train.to(args.device) 81 | 82 | y_ohe = one_hot_encoding(y_train[:, None], 10, args.device) 83 | y_logits = model(x_train) 84 | 85 | loss, categorical_mean, kl, logsoftmax = criterion(y_logits, 86 | y_ohe) 87 | 88 | pred = torch.argmax(logsoftmax, dim=1) 89 | loss.backward() 90 | 91 | if args.clip_grad > 0: 92 | nn.utils.clip_grad.clip_grad_value_(model.parameters(), 93 | args.clip_grad) 94 | 95 | optimizer.step() 96 | 97 | elbo.append(-loss.item()) 98 | cat_mean.append(categorical_mean.item()) 99 | kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl) 100 | accuracy.append(pred2acc(pred, y_train, args.batch_size)) 101 | 102 | t1 = time() - t0 103 | elbo = np.mean(elbo) 104 | cat_mean = np.mean(cat_mean) 105 | kl = np.mean(kls) 106 | accuracy = np.mean(accuracy) 107 | logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean, 108 | tr_time=t1) 109 | 110 | model.eval() 111 | t_dvi = time() 112 | test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args) 113 | t_dvi = time() - t_dvi 114 | 115 | if not args.no_mc: 116 | t_mc = time() 117 | test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args) 118 | t_mc = time() - t_mc 119 | logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc) 120 | 121 | test_acc_samples = evaluate(model, test_loader, mode='samples_dvi', 122 | args=args) 123 | logger.add(epoch, te_acc_dvi=test_acc_dvi, 124 | te_acc_samples=test_acc_samples, te_time_dvi=t_dvi) 125 | 126 | logger.iter_info() 127 | logger.save(silent=True) 128 | torch.save(model.state_dict(), logger.checkpoint) 129 | logger.save() 130 | -------------------------------------------------------------------------------- /scripts/lenet-vdo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | 6 | from core.logger import Logger 7 | from core.losses import ClassificationLoss 8 | from core.models import * 9 | from core.utils import load_mnist, one_hot_encoding, evaluate, pred2acc 10 | 11 | np.random.seed(42) 12 | 13 | EPS = 1e-6 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--device', type=int, default=0) 18 | parser.add_argument('--anneal_updates', type=int, default=1) 19 | parser.add_argument('--warmup_updates', type=int, default=0) 20 | parser.add_argument('--mcvi', action='store_true') 21 | parser.add_argument('--mc_samples', default=20, type=int) 22 | parser.add_argument('--clip_grad', type=float, default=0.5) 23 | parser.add_argument('--lr', type=float, default=1e-3) 24 | parser.add_argument('--epochs', type=int, default=150) 25 | parser.add_argument('--zm', action='store_true') 26 | 
parser.add_argument('--no_mc', action='store_true')
27 | parser.add_argument('--use_samples', action='store_true')
28 | parser.add_argument('--reshape', action='store_true', default=False)
29 | 
30 | parser.add_argument('--vdo1', action='store_true')
31 | parser.add_argument('--vdo2', action='store_true')
32 | parser.add_argument('--vdo3', action='store_true')
33 | parser.add_argument('--vdo4', action='store_true')
34 | parser.add_argument('--vdo5', action='store_true')
35 | 
36 | fmt = {'kl': '3.3e',
37 |        'tr_elbo': '3.3e',
38 |        'tr_acc': '.4f',
39 |        'tr_ll': '.3f',
40 |        'te_acc_dvi': '.4f',
41 |        'te_acc_mcvi': '.4f',
42 |        'te_acc_samples': '.4f',
43 |        'te_acc_dvi_zm': '.4f',
44 |        'te_acc_mcvi_zm': '.4f',
45 |        'tr_time': '.3f',
46 |        'te_time_dvi': '.3f',
47 |        'te_time_mcvi': '.3f'}
48 | 
49 | if __name__ == "__main__":
50 |     args = parser.parse_args()
51 |     args.device = torch.device(
52 |         'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')
53 | 
54 |     model = LeNetVDO(args).to(args.device)
55 | 
56 |     args.batch_size, args.test_batch_size = 32, 32
57 |     train_loader, test_loader = load_mnist(args)
58 |     args.data_size = len(train_loader.dataset)
59 | 
60 |     i = 0  # keep the counter outside the loop so each VDO layer gets its own key
61 |     for layer in model.children():
62 |         if hasattr(layer, 'log_alpha'):
63 |             fmt.update({'{}_log_a'.format(i + 1): '3.3e'})
64 |             i += 1
65 | 
66 |     logger = Logger('lenet-vdo', fmt=fmt)
67 |     logger.print(args)
68 |     logger.print(model)
69 | 
70 |     criterion = ClassificationLoss(model, args)
71 |     optimizer = torch.optim.Adam([p for p in model.parameters()
72 |                                   if p.requires_grad], lr=args.lr)
73 | 
74 |     for epoch in range(args.epochs):
75 |         t0 = time()
76 | 
77 |         model.train()
78 |         model.set_flag('zero_mean', False)
79 |         criterion.step()
80 | 
81 |         elbo, cat_mean, kls, accuracy = [], [], [], []
82 |         for data, y_train in train_loader:
83 |             optimizer.zero_grad()
84 | 
85 |             x_train = data.to(args.device)
86 |             y_train = y_train.to(args.device)
87 | 
88 |             y_ohe = one_hot_encoding(y_train[:, None], 10, args.device)
89 |             y_logits = model(x_train)
90 | 
91 |             loss, categorical_mean, kl, logsoftmax = criterion(y_logits,
92 |                                                                y_ohe)
93 | 
94 |             pred = torch.argmax(logsoftmax, dim=1)
95 |             loss.backward()
96 | 
97 |             if args.clip_grad > 0:
98 |                 nn.utils.clip_grad.clip_grad_value_(model.parameters(),
99 |                                                     args.clip_grad)
100 | 
101 |             optimizer.step()
102 | 
103 |             elbo.append(-loss.item())
104 |             cat_mean.append(categorical_mean.item())
105 |             kls.append(kl.item() if isinstance(kl, torch.Tensor) else kl)
106 |             accuracy.append(pred2acc(pred, y_train, args.batch_size))
107 | 
108 |         t1 = time() - t0
109 |         elbo = np.mean(elbo)
110 |         cat_mean = np.mean(cat_mean)
111 |         kl = np.mean(kls)
112 |         accuracy = np.mean(accuracy)
113 |         logger.add(epoch, kl=kl, tr_elbo=elbo, tr_acc=accuracy, tr_ll=cat_mean,
114 |                    tr_time=t1)
115 | 
116 |         model.eval()
117 |         t_dvi = time()
118 |         test_acc_dvi = evaluate(model, test_loader, mode='dvi', args=args)
119 |         t_dvi = time() - t_dvi
120 | 
121 |         if not args.no_mc:
122 |             t_mc = time()
123 |             test_acc_mcvi = evaluate(model, test_loader, mode='mcvi', args=args)
124 |             t_mc = time() - t_mc
125 |             logger.add(epoch, te_acc_mcvi=test_acc_mcvi, te_time_mcvi=t_mc)
126 | 
127 |         test_acc_samples = evaluate(model, test_loader, mode='samples_dvi',
128 |                                     args=args)
129 |         logger.add(epoch, te_acc_dvi=test_acc_dvi,
130 |                    te_acc_samples=test_acc_samples, te_time_dvi=t_dvi)
131 | 
132 |         test_acc_zero_mean_dvi = evaluate(model, test_loader, mode='dvi',
133 |                                           args=args, zero_mean=True)
134 |         logger.add(epoch, te_acc_dvi_zm=test_acc_zero_mean_dvi)
135 |         if 
not args.no_mc: 136 | test_acc_zero_mean_mcvi = evaluate(model, test_loader, 137 | mode='mcvi', 138 | args=args, zero_mean=True) 139 | logger.add(epoch, te_acc_mcvi_zm=test_acc_zero_mean_mcvi) 140 | 141 | i = 0 142 | alphas = {} 143 | for layer in model.children(): 144 | if hasattr(layer, 'log_alpha'): 145 | alphas.update( 146 | {'{}_log_a'.format(i + 1): layer.log_alpha.item()}) 147 | i += 1 148 | logger.add(epoch, **alphas) 149 | 150 | logger.iter_info() 151 | logger.save(silent=True) 152 | torch.save(model.state_dict(), logger.checkpoint) 153 | logger.save() 154 | -------------------------------------------------------------------------------- /scripts/toy_data_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from core.layers import LinearGaussian, ReluGaussian 8 | from core.losses import ClassificationLoss 9 | from core.utils import generate_classification_data, draw_classification_results 10 | 11 | np.random.seed(42) 12 | 13 | EPS = 1e-6 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--batch_size', type=int, default=500) 18 | parser.add_argument('--mcvi', action='store_true') 19 | parser.add_argument('--hid_size', type=int, default=128) 20 | parser.add_argument('--n_classes', type=int, default=2) 21 | parser.add_argument('--data_size', type=int, default=500) 22 | parser.add_argument('--method', type=str, default='bayes') 23 | parser.add_argument('--device', type=int, default=0) 24 | parser.add_argument('--anneal_updates', type=int, default=1000) 25 | parser.add_argument('--warmup_updates', type=int, default=14000) 26 | parser.add_argument('--test_size', type=int, default=100) 27 | parser.add_argument('--lr', type=float, default=1e-2) 28 | parser.add_argument('--gamma', type=float, default=0.5, 29 | help='lr decrease rate in MultiStepLR scheduler') 30 | parser.add_argument('--epochs', type=int, default=23000) 31 | parser.add_argument('--draw_every', type=int, default=1000) 32 | parser.add_argument('--milestones', nargs='+', type=int, 33 | default=[3000, 5000, 9000, 13000]) 34 | parser.add_argument('--dataset', default='classification') 35 | parser.add_argument('--input_size', default=2, type=int) 36 | parser.add_argument('--mc_samples', default=1, type=int) 37 | 38 | 39 | class Model(nn.Module): 40 | def __init__(self, args): 41 | super().__init__() 42 | hid_size = args.hid_size 43 | self.linear = LinearGaussian(args.input_size, hid_size, certain=True) 44 | self.relu1 = ReluGaussian(hid_size, hid_size) 45 | self.out = ReluGaussian(hid_size, args.n_classes) 46 | 47 | if args.mcvi: 48 | self.mcvi() 49 | 50 | def forward(self, x): 51 | x = self.linear(x) 52 | x = self.relu1(x) 53 | return self.out(x) 54 | 55 | def mcvi(self): 56 | self.linear.mcvi() 57 | self.relu1.mcvi() 58 | self.out.mcvi() 59 | 60 | def determenistic(self): 61 | self.linear.determenistic() 62 | self.relu1.determenistic() 63 | self.out.determenistic() 64 | 65 | 66 | if __name__ == "__main__": 67 | args = parser.parse_args() 68 | args.device = torch.device( 69 | 'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu') 70 | 71 | x_train, y_train, y_onehot_train, x_test, y_test, y_onehot_test = generate_classification_data( 72 | args) 73 | draw_classification_results(x_test, y_test, 'test.png', args) 74 | draw_classification_results(x_train, y_train, 'train.png', args) 75 | 76 | model = Model(args).to(args.device) 77 | criterion = 
ClassificationLoss(model, args)
78 | 
79 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
80 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
81 |                                                      args.milestones,
82 |                                                      gamma=args.gamma)
83 | 
84 |     step = 0
85 | 
86 |     for epoch in range(args.epochs):
87 |         step += 1
88 |         optimizer.zero_grad()
89 | 
90 |         y_logits = model(x_train)
91 |         loss, categorical_mean, kl, logsoftmax = criterion(y_logits,
92 |                                                            y_onehot_train, step)
93 | 
94 |         pred = torch.argmax(logsoftmax, dim=1)
95 |         loss.backward()
96 | 
97 |         nn.utils.clip_grad.clip_grad_value_(model.parameters(), 0.1)
98 |         optimizer.step()
99 |         scheduler.step()  # step the scheduler after the optimizer
100 | 
101 |         if epoch % args.draw_every == 0:
102 |             draw_classification_results(x_train, pred,
103 |                                         'after_{}_epoch.png'.format(epoch),
104 |                                         args)
105 | 
106 |     with torch.no_grad():
107 |         y_logits = model(x_train)
108 |         _, _, _, logsoftmax = criterion(y_logits, y_onehot_train, step)
109 |         pred = torch.argmax(logsoftmax, dim=1)
110 |         draw_classification_results(x_train, pred, 'end_train.png', args)
111 | 
112 |         y_logits = model(x_test)
113 |         _, _, _, logsoftmax = criterion(y_logits, y_onehot_test, step)
114 |         pred = torch.argmax(logsoftmax, dim=1)
115 |         draw_classification_results(x_test, pred, 'end_test.png', args)
116 | 
--------------------------------------------------------------------------------
/scripts/toy_data_regression.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | 
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | 
8 | from core.layers import LinearGaussian, ReluGaussian
9 | from core.losses import RegressionLoss
10 | from core.utils import draw_regression_result, get_predictions, \
11 |     generate_regression_data
12 | 
13 | np.random.seed(42)
14 | 
15 | EPS = 1e-6
16 | 
17 | parser = argparse.ArgumentParser()
18 | 
19 | parser.add_argument('--batch_size', type=int, default=500)
20 | parser.add_argument('--mcvi', action='store_true')
21 | parser.add_argument('--hid_size', type=int, default=128)
22 | parser.add_argument('--heteroskedastic', default=False, action='store_true')
23 | parser.add_argument('--data_size', type=int, default=500)
24 | parser.add_argument('--homo_var', type=float, default=0.35)
25 | parser.add_argument('--homo_log_var_scale', type=float,
26 |                     default=2 * math.log(0.2))
27 | parser.add_argument('--method', type=str, default='bayes')
28 | parser.add_argument('--device', type=int, default=0)
29 | parser.add_argument('--anneal_updates', type=int, default=1000)
30 | parser.add_argument('--warmup_updates', type=int, default=14000)
31 | parser.add_argument('--test_size', type=int, default=100)
32 | parser.add_argument('--lr', type=float, default=1e-2)
33 | parser.add_argument('--gamma', type=float, default=0.5,
34 |                     help='lr decrease rate in MultiStepLR scheduler')
35 | parser.add_argument('--epochs', type=int, default=23000)
36 | parser.add_argument('--draw_every', type=int, default=1000)
37 | parser.add_argument('--milestones', nargs='+', type=int,
38 |                     default=[3000, 5000, 9000, 13000])
39 | 
40 | 
41 | def base_model(x):
42 |     return -(x + 0.5) * np.sin(3 * np.pi * x)
43 | 
44 | 
45 | def noise_model(x, args):
46 |     if args.heteroskedastic:
47 |         return 1 * (x + 0.5) ** 2
48 |     else:
49 |         return args.homo_var
50 | 
51 | 
52 | def sample_data(x, args):
53 |     return base_model(x) + np.random.normal(0, noise_model(x, args),
54 |                                             size=x.size).reshape(x.shape)
55 | 
56 | 
57 | class Model(nn.Module):
58 |     def __init__(self, args):
59 |         super().__init__()
60 |         hid_size = args.hid_size
61 |         self.linear = LinearGaussian(1, hid_size, certain=True)
62 |         self.relu1 = ReluGaussian(hid_size, hid_size)
63 |         if args.heteroskedastic:
64 |             self.out = ReluGaussian(hid_size, 2)
65 |         else:
66 |             self.out = ReluGaussian(hid_size, 1)
67 | 
68 |         if args.mcvi:
69 |             self.mcvi()
70 | 
71 |     def forward(self, x):
72 |         x = self.linear(x)
73 |         x = self.relu1(x)
74 |         return self.out(x)
75 | 
76 |     def mcvi(self):
77 |         self.linear.mcvi()
78 |         self.relu1.mcvi()
79 |         self.out.mcvi()
80 | 
81 |     def determenistic(self):  # spelling follows the core.layers API
82 |         self.linear.determenistic()
83 |         self.relu1.determenistic()
84 |         self.out.determenistic()
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     args = parser.parse_args()
89 |     args.device = torch.device(
90 |         'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')
91 | 
92 |     print(args)
93 | 
94 |     model = Model(args).to(args.device)
95 |     loss = RegressionLoss(model, args)
96 |     x_train, y_train, x_test, y_test, toy_data = generate_regression_data(args,
97 |                                                                           sample_data,
98 |                                                                           base_model)
99 | 
100 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
101 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
102 |                                                      args.milestones,
103 |                                                      gamma=args.gamma)
104 | 
105 |     if args.mcvi:
106 |         mode = 'mcvi'
107 |     else:
108 |         mode = 'det'
109 |     step = 0
110 | 
111 |     for epoch in range(args.epochs):
112 |         step += 1
113 |         optimizer.zero_grad()
114 | 
115 |         pred = model(x_train)
116 |         neg_elbo, ll, kl = loss(pred, y_train, step)
117 | 
118 |         neg_elbo.backward()
119 | 
120 |         nn.utils.clip_grad.clip_grad_value_(model.parameters(), 0.1)
121 |         optimizer.step()
122 |         scheduler.step()  # step the scheduler after the optimizer
123 | 
124 |         if epoch % args.draw_every == 0:
125 |             print("epoch : {}".format(epoch))
126 |             print("ELBO : {:.4f}\t Likelihood: {:.4f}\t KL: {:.4f}".format(
127 |                 -neg_elbo.item(), ll.item(), kl.item()))
128 | 
129 |         if epoch % args.draw_every == 0:
130 |             with torch.no_grad():
131 |                 predictions = get_predictions(x_train, model, args, args.mcvi)
132 |                 draw_regression_result(toy_data,
133 |                                        {'mean': base_model,
134 |                                         'std': lambda x: noise_model(x,
135 |                                                                      args)},
136 |                                        predictions=predictions,
137 |                                        name='pics/{}/after_{}.png'.format(
138 |                                            mode, epoch))
139 | 
140 |     with torch.no_grad():
141 |         predictions = get_predictions(x_train, model, args, args.mcvi)
142 |         draw_regression_result(toy_data,
143 |                                {'mean': base_model,
144 |                                 'std': lambda x: noise_model(x, args)},
145 |                                predictions=predictions,
146 |                                name='pics/{}/last.png'.format(
147 |                                    mode))
148 | 
149 |     if args.mcvi:
150 |         model.determenistic()
151 |         with torch.no_grad():
152 |             predictions = get_predictions(x_train, model, args, False)
153 |             draw_regression_result(toy_data,
154 |                                    {'mean': base_model,
155 |                                     'std': lambda x: noise_model(x, args)},
156 |                                    predictions=predictions,
157 |                                    name='pics/{}/swapped.png'.format(
158 |                                        mode))
159 |     else:
160 |         model.mcvi()
161 |         with torch.no_grad():
162 |             predictions = get_predictions(x_train, model, args, True)
163 |             draw_regression_result(toy_data,
164 |                                    {'mean': base_model,
165 |                                     'std': lambda x: noise_model(x, args)},
166 |                                    predictions=predictions,
167 |                                    name='pics/{}/swapped.png'.format(
168 |                                        mode))
169 | 
--------------------------------------------------------------------------------
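
A minimal usage sketch (hypothetical; not a file in this repo) of how the pieces above fit together: build the VDO LeNet with the same flags scripts/lenet-vdo.py defines, load MNIST via core.utils.load_mnist, and score it with core.utils.evaluate in closed-form DVI mode. The argparse.Namespace fields mirror the parsers in the scripts; run it from the repository root so the core package imports, and note that load_mnist downloads MNIST to the hard-coded '../../data' path.

import argparse

import torch

from core.models import LeNetVDO
from core.utils import evaluate, load_mnist

# Namespace fields mirror the flags in scripts/lenet-vdo.py;
# vdo1..vdo5 choose which LeNet layers use variational dropout.
args = argparse.Namespace(
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    mcvi=False, mc_samples=20, use_samples=False, reshape=False,
    batch_size=32, test_batch_size=32,
    vdo1=True, vdo2=False, vdo3=False, vdo4=False, vdo5=False)

model = LeNetVDO(args).to(args.device)
_, test_loader = load_mnist(args)

model.eval()
acc = evaluate(model, test_loader, mode='dvi', args=args)  # closed-form moment propagation
print('te_acc_dvi: {:.4f}'.format(acc))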