├── pytest.ini ├── run_mypy.sh ├── .gitignore ├── distributions ├── __init__.py ├── distributions.py ├── bernoulli.py ├── gaussian.py └── poisson.py ├── deep_exp_fam ├── __init__.py ├── layers.py └── deep_exponential_family_model.py ├── mypy.ini ├── environment.yml ├── tests ├── test_poisson.py ├── test_gaussian_def.py ├── test_bernoulli.py ├── tester.py ├── gibbs_gaussian_factor_model.py ├── test_fast_bernoulli.py ├── test_score_function.py ├── test_control_variate.py ├── test_minibatches.py ├── test_variational_expectation_maximization.py ├── score_function_gaussian_factor_model.py └── test_multivariate_def.py ├── LICENSE ├── experiments ├── poisson_gaussian_deep_exp_fam_mnist.py └── poisson_gaussian_deep_exp_fam_text.py ├── README.md └── common ├── fit.py ├── util.py └── data.py /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | -------------------------------------------------------------------------------- /run_mypy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mypy $(find . -name '*.py') 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.cache 2 | *.gitignore 3 | *.mypy_cache 4 | **__pycache__ 5 | *.pyc 6 | 7 | -------------------------------------------------------------------------------- /distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian import * 2 | from .poisson import * 3 | from .bernoulli import * 4 | -------------------------------------------------------------------------------- /deep_exp_fam/__init__.py: -------------------------------------------------------------------------------- 1 | from .deep_exponential_family_model import DeepExponentialFamilyModel 2 | from . import layers 3 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.6 3 | ignore_missing_imports = True 4 | follow_imports = skip 5 | incremental = True 6 | check_untyped_defs = True 7 | warn_unused_ignores = True 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deep_exp_fam 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python>=3.5 6 | - pytest 7 | - pyyaml 8 | - scipy 9 | - ipython 10 | - scikit-learn 11 | - h5py 12 | - pillow 13 | - gensim 14 | - pip: 15 | - mypy 16 | - ipdb 17 | - mxnet>=mxnet-0.12.0b20171016 18 | -------------------------------------------------------------------------------- /tests/test_poisson.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import distributions 3 | import mxnet as mx 4 | import scipy.stats 5 | import scipy.special 6 | 7 | from mxnet import nd 8 | 9 | mx.random.seed(13343) 10 | 11 | 12 | def test_poisson_sampling(): 13 | rate = 5. 14 | n_samples = 10000 15 | samples = distributions.Poisson(nd.array([rate])).sample(n_samples) 16 | mean = nd.mean(samples).asnumpy() 17 | np.testing.assert_allclose(mean, rate, rtol=1e-2) 18 | 19 | 20 | def test_poisson_log_prob(): 21 | rate = 1. 
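  # scipy.stats.poisson.logpmf below evaluates the reference value
  #   log p(x | rate) = x * log(rate) - rate - log(x!),
  # the same expression distributions.Poisson.log_prob computes, with
  # nd.gammaln(x + 1.) standing in for log(x!).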
22 | data = [2, 5, 0, 10, 4] 23 | np_log_prob = scipy.stats.poisson.logpmf(np.array(data), mu=np.array(rate)) 24 | p = distributions.Poisson(nd.array([rate])) 25 | mx_log_prob = p.log_prob(nd.array(data)).asnumpy() 26 | np.testing.assert_allclose(mx_log_prob, np_log_prob) 27 | -------------------------------------------------------------------------------- /tests/test_gaussian_def.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tester 3 | 4 | 5 | def test_latent_layer(): 6 | """Test matching q(z) to p(z) where p is Gaussian.""" 7 | mean = 5. 8 | 9 | config = """ 10 | n_iterations: 50 11 | learning_rate: 0.1 12 | gradient: 13 | estimator: pathwise 14 | n_samples: 1 15 | batch_size: 1 16 | layer_1: 17 | latent_distribution: gaussian 18 | size: 1 19 | p_z_mean: {} 20 | p_z_variance: 1. 21 | """.format(mean) 22 | 23 | def test_posterior_predictive(sample: np.array) -> None: 24 | print('posterior predictive mean:', sample) 25 | print('prior mean:', mean) 26 | np.testing.assert_allclose(sample, mean, rtol=1e-1) 27 | 28 | tester.test(config, data=np.array([np.nan]), 29 | test_fn=test_posterior_predictive) 30 | -------------------------------------------------------------------------------- /tests/test_bernoulli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import distributions 3 | import scipy.stats 4 | import mxnet as mx 5 | 6 | 7 | mx.random.seed(13343) 8 | 9 | 10 | def test_bernoulli_sampling(): 11 | logits = 0.232 12 | n_samples = 10000 13 | p = distributions.Bernoulli(mx.nd.array([logits])) 14 | samples = p.sample(n_samples) 15 | mean = mx.nd.mean(samples).asnumpy() 16 | print('sampling mean, mean', mean, p.mean.asnumpy()) 17 | np.testing.assert_allclose(mean, p.mean.asnumpy(), rtol=1e-2) 18 | 19 | 20 | def test_bernoulli_log_prob(): 21 | logits = 0.384 22 | data = [0, 1, 0, 0, 1] 23 | p = distributions.Bernoulli(mx.nd.array([logits])) 24 | np_log_prob = scipy.stats.bernoulli.logpmf( 25 | np.array(data), p=p.mean.asnumpy()) 26 | mx_log_prob = p.log_prob(mx.nd.array(data)).asnumpy() 27 | np.testing.assert_allclose(mx_log_prob, np_log_prob) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Jaan Altosaar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/tester.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import yaml 4 | 5 | from common import fit 6 | from mxnet import gluon 7 | from mxnet import nd 8 | from typing import Callable 9 | 10 | import deep_exp_fam 11 | 12 | 13 | def test(config_yaml: str, data: np.array, test_fn: Callable = None): 14 | np.random.seed(23423) 15 | 16 | def get_data_iter(batch_size, shuffle): 17 | dataset = gluon.data.ArrayDataset( 18 | data.astype(np.float32), range(len(data))) 19 | return gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 20 | config = yaml.load(config_yaml) 21 | data_iter = get_data_iter(config['gradient']['batch_size'], shuffle=True) 22 | my_model = fit.fit(config, data_iter) 23 | data_iter = get_data_iter(config['gradient']['batch_size'], shuffle=False) 24 | n_samples_stats = 10 25 | for data_batch in data_iter: 26 | _, _, sample = my_model(data_batch) 27 | tmp_sample = nd.zeros_like(sample) 28 | for _ in range(n_samples_stats): 29 | _, _, sample = my_model(data_batch) 30 | tmp_sample += sample 31 | tmp_sample /= n_samples_stats 32 | if tmp_sample.ndim == 3: 33 | tmp_sample = nd.mean(tmp_sample, 0, keepdims=True) 34 | tmp_sample = tmp_sample.asnumpy() 35 | if len(data) == 1: 36 | tmp_sample = tmp_sample.reshape((1, -1)) 37 | test_fn(tmp_sample) 38 | else: 39 | test_fn(tmp_sample, data_batch[0].asnumpy()) 40 | -------------------------------------------------------------------------------- /tests/gibbs_gaussian_factor_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | """Run Gibbs sampling in the model: 3 | 4 | z ~ N(0, 1) 5 | w ~ N(0, 1) 6 | x ~ N(zw, 1) 7 | """ 8 | np.random.seed(1323) 9 | 10 | 11 | def sample_latent(mu_a, var_a, b_sample, mu_b, var_b, x, var_x): 12 | """Sample a latent variable w or z.""" 13 | mean = ((1. / var_x * x * b_sample + 1. / var_a * mu_a) / 14 | (1. / var_x * np.square(b_sample) + 1. / var_a)) 15 | var = 1. / (1. / var_x * np.square(b_sample) + 1. / var_a) 16 | return mean + np.sqrt(var) * np.random.normal() 17 | 18 | 19 | def normal_log_prob(x, mean, var): 20 | return -0.5 * np.log(2. * np.pi * var) - 0.5 * np.square(x - mean) / var 21 | 22 | 23 | def run_gibbs_sampling(): 24 | mu_z = 0. 25 | var_z = 1. 26 | mu_w = 0. 27 | var_w = 1. 28 | x = 30.3 29 | var_x = 1.
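  # Complete conditional used by sample_latent: with x ~ N(z * w, var_x) and a
  # N(mu_a, var_a) prior on the latent variable being resampled, completing the
  # square gives
  #   precision = b_sample ** 2 / var_x + 1. / var_a
  #   mean = (b_sample * x / var_x + mu_a / var_a) / precision
  # and each Gibbs step draws from N(mean, 1. / precision).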
30 | 31 | w_sample = mu_w + np.sqrt(var_w) * np.random.normal() 32 | for step in range(100): 33 | z_sample = sample_latent(mu_z, var_z, w_sample, mu_w, var_w, x, var_x) 34 | w_sample = sample_latent(mu_w, var_w, z_sample, mu_z, var_z, x, var_x) 35 | log_prob = (normal_log_prob(x, z_sample * w_sample, var_x) + 36 | normal_log_prob(w_sample, mu_w, var_w) + 37 | normal_log_prob(z_sample, mu_z, var_z)) 38 | if step % 10 == 0: 39 | print('step:', step) 40 | print('log joint', log_prob) 41 | print('z_sample:', z_sample) 42 | print('w_sample:', w_sample) 43 | x_sample = z_sample * w_sample + np.sqrt(var_x) * np.random.normal() 44 | print('posterior predictive mean:', z_sample * w_sample) 45 | 46 | 47 | if __name__ == '__main__': 48 | run_gibbs_sampling() 49 | -------------------------------------------------------------------------------- /experiments/poisson_gaussian_deep_exp_fam_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import yaml 4 | import os 5 | import time 6 | import mxnet as mx 7 | 8 | from mxnet import nd 9 | from common import fit 10 | 11 | 12 | if __name__ == '__main__': 13 | path = os.path.join( 14 | os.environ['LOG'], 'mnist/' + time.strftime("%Y-%m-%d")) 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | 18 | config = yaml.load( 19 | """ 20 | dir: {} 21 | clear_dir: false 22 | use_gpu: false 23 | learning_rate: 0.01 24 | n_iterations: 100000 25 | print_every: 1 26 | gradient: 27 | estimator: score_function 28 | n_samples: 32 29 | batch_size: 10 30 | layer_1: 31 | latent_distribution: poisson 32 | # weight_distribution: point_mass 33 | size: 10 34 | layer_0: 35 | weight_distribution: point_mass 36 | data_distribution: bernoulli 37 | data_size: 784 38 | """.format(path)) 39 | 40 | if config['clear_dir']: 41 | for f in os.listdir(config['dir']): 42 | os.remove(os.path.join(config['dir'], f)) 43 | 44 | # hdf5 file from: 45 | # https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py 46 | data_path = os.path.join(os.environ['DAT'], 'binarized_mnist.hdf5') 47 | 48 | f = h5py.File(data_path, 'r') 49 | raw_data = f['train'][:][0:10] 50 | f.close() 51 | 52 | train_data = mx.io.NDArrayIter( 53 | data={'data': nd.array(raw_data)}, 54 | label={'label': range(len(raw_data)) * np.ones((len(raw_data),))}, 55 | batch_size=config['gradient']['batch_size'], 56 | shuffle=True) 57 | 58 | fit.fit(config, train_data) 59 | -------------------------------------------------------------------------------- /tests/test_fast_bernoulli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import distributions 4 | import scipy.stats 5 | import scipy.special 6 | import mxnet as mx 7 | from mxnet import nd 8 | 9 | 10 | mx.random.seed(13343) 11 | np.random.seed(2324) 12 | 13 | 14 | def test_bernoulli_sampling(): 15 | n_samples = 10000 16 | K = 10 # num factors 17 | C = 2 # num classes 18 | # latent variable is of size [n_samples, batch_size, latent_size] 19 | positive_latent = nd.ones((1, 1, K)) * 0.01 20 | weight = nd.ones((K, C)) * 0.1 21 | bias = nd.ones(C) * 0.01 22 | p = distributions.FastBernoulli( 23 | positive_latent=positive_latent, weight=weight, bias=bias) 24 | samples = p.sample(n_samples) 25 | print(samples.shape) 26 | mean = nd.mean(samples, 0).asnumpy() 27 | print('sampling mean, mean', mean, p.mean.asnumpy()) 28 | np.testing.assert_allclose(mean, p.mean.asnumpy(), rtol=1e-1) 29 | 30 | 31 | def test_bernoulli_log_prob(): 32 | K
= 10 # num factors 33 | C = 100 # num classes 34 | positive_latent = nd.ones((1, 1, K)) * nd.array(np.random.rand(K)) 35 | weight = nd.ones((K, C)) * nd.array(np.random.rand(K, C)) 36 | bias = nd.ones(C) * 0.01 37 | data = np.random.binomial(n=1, p=0.1, size=C) 38 | assert np.sum(data) > 0 39 | nonzero_idx = np.nonzero(data)[0] 40 | p = distributions.FastBernoulli( 41 | positive_latent=positive_latent, weight=weight, bias=bias) 42 | np_log_prob_sum = scipy.stats.bernoulli.logpmf( 43 | np.array(data), p=p.mean.asnumpy()).sum() 44 | mx_log_prob_sum = p.log_prob_sum( 45 | nonzero_index=nd.array(nonzero_idx)).asnumpy() 46 | print('mx log prob sum, np log prob sum', mx_log_prob_sum, np_log_prob_sum) 47 | np.testing.assert_allclose(mx_log_prob_sum, np_log_prob_sum, rtol=1e-3) 48 | -------------------------------------------------------------------------------- /tests/test_score_function.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import mxnet as mx 3 | import numpy as np 4 | import yaml 5 | 6 | from mxnet import gluon 7 | from mxnet import nd 8 | from mxnet import autograd 9 | 10 | import deep_exp_fam 11 | import distributions 12 | import tester 13 | 14 | 15 | def test_latent_gaussian_layer(): 16 | """Test matching q(z) to p(z) where p is Gaussian.""" 17 | mean = 3.9 18 | 19 | config = """ 20 | n_iterations: 100 21 | learning_rate: 0.1 22 | gradient: 23 | estimator: score_function 24 | n_samples: 16 25 | batch_size: 1 26 | layer_1: 27 | latent_distribution: gaussian 28 | p_z_mean: {} 29 | p_z_variance: 1. 30 | size: 1 31 | """.format(mean) 32 | 33 | def test_posterior_predictive(sample: np.array) -> None: 34 | print('posterior predictive sample:', sample) 35 | print('prior mean:', mean) 36 | np.testing.assert_allclose(sample, mean, rtol=1e-1) 37 | 38 | tester.test(config, data=np.array([np.nan]), 39 | test_fn=test_posterior_predictive) 40 | 41 | 42 | def test_latent_poisson_layer(): 43 | """Test matching q(z) to p(z) where p is Poisson.""" 44 | mean = 5. 
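  # The score function (REINFORCE) estimator below computes the gradient of
  # E_q(z)[f(z)] as E_q(z)[f(z) * d/d(lambda) log q(z; lambda)], averaged over
  # n_samples Monte Carlo draws. Unlike the pathwise estimator it needs no
  # reparameterized sampler, so it also handles the discrete Poisson latent
  # variable in this test.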
45 | 46 | config = """ 47 | n_iterations: 100 48 | learning_rate: 0.1 49 | gradient: 50 | estimator: score_function 51 | n_samples: 16 52 | batch_size: 1 53 | layer_1: 54 | latent_distribution: poisson 55 | size: 1 56 | p_z_mean: {} 57 | """.format(mean) 58 | 59 | def test_posterior_predictive(sample: np.array) -> None: 60 | print('posterior predictive sample:', sample) 61 | print('prior mean:', mean) 62 | np.testing.assert_allclose(sample, mean, rtol=1e-1) 63 | 64 | tester.test(config, data=np.array([np.nan]), 65 | test_fn=test_posterior_predictive) 66 | -------------------------------------------------------------------------------- /distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | from mxnet import nd 5 | from mxnet import gluon 6 | 7 | 8 | class BaseDistribution(gluon.Block, metaclass=abc.ABCMeta): 9 | def __init__(self): 10 | super(BaseDistribution, self).__init__() 11 | 12 | @abc.abstractproperty 13 | def is_reparam(self): 14 | pass 15 | 16 | @abc.abstractmethod 17 | def sample(self, n_samples: int) -> nd.NDArray: 18 | pass 19 | 20 | @abc.abstractmethod 21 | def log_prob(self, x: nd.NDArray) -> nd.NDArray: 22 | pass 23 | 24 | def get_param_not_repeated(self, name): 25 | """Return a parameter without any repetitions.""" 26 | param = getattr(self, name) 27 | if isinstance(param, gluon.Parameter): 28 | res = param.data() 29 | if hasattr(param, 'link_function'): 30 | return param.link_function(res) 31 | else: 32 | return res 33 | else: 34 | return param 35 | 36 | def get_param_maybe_repeated(self, name): 37 | """Return repeated parameter if it has been repeated to get jacobians.""" 38 | if hasattr(self, name + '_repeated'): 39 | param = getattr(self, name + '_repeated') 40 | else: 41 | param = getattr(self, name) 42 | 43 | if isinstance(param, nd.NDArray): 44 | return param 45 | elif isinstance(param, gluon.Parameter): 46 | if hasattr(param, 'repeated'): 47 | res = param.repeated 48 | else: 49 | res = param.data() 50 | 51 | if hasattr(param, 'link_function'): 52 | return param.link_function(res) 53 | else: 54 | return res 55 | else: 56 | raise ValueError('Parameter has invalid type: %s' % type(param)) 57 | 58 | def forward(self): 59 | raise ValueError('Ambiguous: Need to call log_prob or sample!') 60 | -------------------------------------------------------------------------------- /experiments/poisson_gaussian_deep_exp_fam_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gensim 3 | import yaml 4 | import os 5 | import time 6 | from mxnet import nd 7 | from mxnet import gluon 8 | from common import fit 9 | from common import util 10 | from common import data 11 | 12 | 13 | if __name__ == '__main__': 14 | path = os.path.join(os.environ['LOG'], 'text/' + time.strftime("%Y-%m-%d")) 15 | 16 | if not os.path.exists(path): 17 | os.makedirs(path) 18 | 19 | cfg = yaml.load( 20 | """ 21 | dir: {} 22 | clear_dir: false 23 | # IMPORTANT 24 | use_gpu: true 25 | learning_rate: 0.01 26 | n_iterations: 5000000 27 | print_every: 10000 28 | gradient: 29 | estimator: score_function 30 | n_samples: 32 31 | batch_size: 64 32 | layer_1: 33 | latent_distribution: poisson 34 | q_z_mean: 3. 
35 | size: 100 36 | layer_0: 37 | weight_distribution: point_mass 38 | data_distribution: poisson 39 | data_size: null 40 | """.format(path)) 41 | 42 | if cfg['clear_dir']: 43 | for f in os.listdir(cfg['dir']): 44 | os.remove(os.path.join(cfg['dir'], f)) 45 | 46 | with open(os.path.join(cfg['dir'], 'config.yml'), 'w') as f: 47 | yaml.dump(cfg, f, default_flow_style=False) 48 | 49 | util.log_to_file(os.path.join(cfg['dir'], 'train.log')) 50 | 51 | fname = os.path.join(os.environ['DAT'], 'science/documents_train.dat') 52 | fname_vocab = os.path.join(os.environ['DAT'], 'science/VOCAB-TFIDF-1000.dat') 53 | 54 | corpus = gensim.corpora.bleicorpus.BleiCorpus(fname, fname_vocab) 55 | cfg['layer_0']['data_size'] = len(corpus.id2word) 56 | docs = [doc for doc in corpus if len(doc) > 0] 57 | dataset = gluon.data.ArrayDataset(data=docs, label=range(len(docs))) 58 | train_data = data.DocumentDataLoader(dataset=dataset, 59 | id2word=corpus.id2word, 60 | batch_size=cfg['gradient']['batch_size'], 61 | last_batch='discard', 62 | shuffle=True) 63 | 64 | fit.fit(cfg, train_data) 65 | -------------------------------------------------------------------------------- /tests/test_control_variate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import mxnet as mx 3 | import numpy as np 4 | import yaml 5 | 6 | from mxnet import gluon 7 | from mxnet import nd 8 | from mxnet import autograd 9 | 10 | from deep_exp_fam import deep_exponential_family_model as lib 11 | 12 | np.random.seed(32423) 13 | mx.random.seed(3242) 14 | 15 | 16 | def test_linear_time_estimator(): 17 | """Test the control variate estimator with Rajesh's numpy implementation.""" 18 | init_mean = 3.9 19 | n_samples = 256 20 | mean = nd.ones(n_samples) * init_mean 21 | mean.attach_grad() 22 | variance = nd.array([1.]) 23 | sample = nd.stop_gradient( 24 | mean + variance * nd.random_normal(shape=(n_samples,))) 25 | with autograd.record(): 26 | log_prob = (-0.5 * nd.log(2. * np.pi * variance) 27 | - 0.5 * nd.square(sample - mean) / variance) 28 | log_prob.backward() 29 | f = nd.square(sample) - 3. 
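  # After log_prob.backward(), mean.grad holds the per-sample score
  # d/d(mean) log N(sample; mean, variance) = (sample - mean) / variance, since
  # `mean` is a vector with one entry per Monte Carlo sample. The leave-one-out
  # estimator below rescales f by the optimal control variate coefficient
  # a* = Cov(score * f, score) / Var(score) before averaging.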
30 | h = mean.grad 31 | grad = lib._leave_one_out_gradient_estimator(h, f) 32 | fast_grad = lib._leave_one_out_gradient_estimator(h, f, zero_mean_h=True) 33 | np_grad = np.mean(leave_one_out_control_variates(h.asnumpy(), f.asnumpy())) 34 | np.testing.assert_allclose(grad.asnumpy(), np_grad, rtol=1e-3) 35 | np.testing.assert_allclose(fast_grad.asnumpy(), np_grad, rtol=1e-2) 36 | 37 | 38 | def test_held_out_covariance(): 39 | """Test leave-one-out covariance estimation.""" 40 | x = np.random.rand(10) 41 | y = np.random.rand(10) 42 | cov = lib._held_out_covariance(nd.array(x), nd.array(y)) 43 | np_cov = held_out_cov(x, y) 44 | np.testing.assert_allclose(cov.asnumpy(), np_cov, rtol=1e-2) 45 | 46 | 47 | def leave_one_out_control_variates(score, f): 48 | held_out_covariance = held_out_cov(score, f * score) 49 | held_out_variance = held_out_cov(score, score) 50 | optimal_a = held_out_covariance / held_out_variance 51 | grad = score * (f - optimal_a) 52 | return grad 53 | 54 | 55 | def held_out_cov(x, y): 56 | n = len(x) 57 | C = np.cov(x, y)[0, 1] # Slightly wasteful 58 | meanx = np.mean(x) 59 | meany = np.mean(y) 60 | C *= (n - 1) 61 | 62 | meanx_ho = (meanx - x / n) / (1 - 1.0 / n) 63 | meany_ho = (meany - y / n) / (1 - 1.0 / n) 64 | C -= (x - meanx_ho) * (y - meany_ho) 65 | C /= (n - 2) 66 | return C 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Deep exponential families 2 | 3 | This is an implementation of deep exponential families in MXNet/Gluon. DEFs are described in [the paper](https://arxiv.org/abs/1411.2581). 4 | 5 | I found it much easier to implement this in an imperative, dynamic-graph library like MXNet than in autodifferentiation libraries that only support static computation graphs. 6 | 7 | Currently the code only implements point-mass distributions for the weights and biases of each layer in the DEF (these parameters are learned using variational expectation-maximization). It should be straightforward to extend this to other distributions. 8 | 9 | The gradients are computed with either the score function estimator or the pathwise (reparameterization trick) estimator. For score function gradient estimators, we use the optimal control variate scaling described in [black box variational inference](https://arxiv.org/abs/1401.0118). 10 | 11 | The code takes lots of inspiration from the official deep exponential families [codebase](https://github.com/blei-lab/deep-exponential-families) and the Gluon examples in MXNet. 12 | 13 | ### Example 14 | 15 | Train a Poisson deep exponential family model on a large collection of science articles (in the LDA-C format): 16 | ``` 17 | PYTHONPATH=. python experiments/poisson_gaussian_deep_exp_fam_text.py 18 | ``` 19 | This periodically prints the latent factors (dimensions of the latent variable), along with the top words and their weights for each factor. For example, one dimension captures documents about DNA: 20 | ``` 21 | 0.246 fig 22 | -0.358 dna 23 | -0.366 protein 24 | -0.372 cells 25 | -0.430 cell 26 | -0.722 gene 27 | -0.970 binding 28 | -1.010 two 29 | -1.026 sequence 30 | -1.100 proteins 31 | ``` 32 | 33 | To train a Poisson deep exponential family model on the MNIST dataset: 34 | ``` 35 | PYTHONPATH=. python experiments/poisson_gaussian_deep_exp_fam_mnist.py 36 | ``` 37 | 38 | Also see the examples in the `tests/` folder.
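The score function gradients use the control variate scaling mentioned above; for intuition, here is a minimal, standalone NumPy sketch of that estimator (an illustration, not the library's implementation):
```
import numpy as np


def score_function_grad(f, score):
  # Optimal control variate scaling a* = Cov(f * score, score) / Var(score).
  a = np.cov(score, f * score)[0, 1] / np.var(score, ddof=1)
  return np.mean(score * (f - a))


# Fit the mean of q = N(mean, 1) to a N(5, 1) target by stochastic gradient ascent.
mean, target = 0.0, 5.0
for _ in range(2000):
  z = mean + np.random.randn(64)  # Monte Carlo samples from q
  f = -0.5 * (z - target) ** 2    # log target density, up to a constant
  score = z - mean                # d/d(mean) of log q(z; mean)
  mean += 0.1 * score_function_grad(f, score)
print(mean)  # converges to roughly 5.0
```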
39 | 40 | ### Requirements 41 | Install requirements with [anaconda](https://conda.io/docs/user-guide/install/index.html): 42 | ``` 43 | conda env create -f environment.yml 44 | source activate deep_exp_fam 45 | ``` 46 | 47 | ### Testing 48 | Run `PYTHONPATH=. pytest` for unit tests and `mypy $(find . -name '*.py')` for static type-checking. 49 | 50 | ### TODO: 51 | * figure out a cleaner way to do per-sample gradients -- bug tracker: https://github.com/apache/incubator-mxnet/issues/7987 (right now, parameters are repeated in deep_exp_fam.DeepExponentialFamilyModel class and require annoying processing) 52 | * add support for priors on the weights 53 | -------------------------------------------------------------------------------- /common/fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import mxnet as mx 5 | import numpy as np 6 | import deep_exp_fam 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | from mxnet import gluon 12 | from mxnet import nd 13 | from deep_exp_fam import DeepExponentialFamilyModel 14 | from deep_exp_fam import layers 15 | from common import util 16 | from common import data 17 | 18 | 19 | def fit(cfg: dict, 20 | train_data: gluon.data.DataLoader) -> DeepExponentialFamilyModel: 21 | """Fit a deep exponential family model to data.""" 22 | mx.random.seed(32429) 23 | np.random.seed(423323) 24 | ctx = [mx.gpu(0)] if ('use_gpu' in cfg and cfg['use_gpu']) else [mx.cpu()] 25 | with mx.Context(ctx[0]): 26 | my_model = DeepExponentialFamilyModel( 27 | n_data=len(train_data._dataset), gradient_config=cfg['gradient']) 28 | with my_model.name_scope(): 29 | layer_names = sorted([key for key in cfg if 'layer' in key])[::-1] 30 | for name in layer_names: 31 | if name != 'layer_0': 32 | my_model.add(layers.LatentLayer, **cfg[name]) 33 | elif name == 'layer_0': 34 | my_model.add(layers.ObservationLayer, **cfg[name]) 35 | params = my_model.collect_params() 36 | logger.info(params) 37 | latest_step = 0 38 | if 'dir' in cfg: 39 | latest_params, latest_step = util.latest_checkpoint(cfg['dir']) 40 | if latest_params is not None: 41 | my_model.load_params(latest_params, ctx) 42 | step = latest_step if latest_step is not None else 0 43 | params.initialize(ctx=ctx) 44 | my_model.maybe_attach_repeated_params() 45 | trainer = mx.gluon.Trainer( 46 | params, 'rmsprop', {'learning_rate': cfg['learning_rate']}) 47 | print_every = cfg['print_every'] if 'print_every' in cfg else 100 48 | while step <= cfg['n_iterations'] + 1: 49 | for data_batch in train_data: 50 | with mx.autograd.record(): 51 | log_q_sum, elbo, sample = my_model(data_batch, False) 52 | my_model.compute_gradients(elbo, data_batch, log_q_sum) 53 | if step % print_every == 0: 54 | elbo = np.mean(np.mean(elbo.asnumpy(), 0), 0) 55 | if 'layer_0' in cfg and hasattr(train_data, 'id2word'): 56 | w = params.get('observationlayer0_point_mass_weight') 57 | data.print_top_words(w, train_data.id2word) 58 | logger.info('t %d\telbo: %.3e' % (step, elbo)) 59 | if np.isnan(elbo): 60 | raise ValueError('ELBO hit nan!') 61 | if 'dir' in cfg: 62 | param_file = 'my_model.params-iteration-%d' % step 63 | my_model.save_params(os.path.join(cfg['dir'], param_file)) 64 | trainer.step(1) 65 | step += 1 66 | if step >= cfg['n_iterations']: 67 | break 68 | return my_model 69 | -------------------------------------------------------------------------------- /tests/test_minibatches.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tester 3 | 4 | 5 | def test_gaussian(): 6 | config = """ 7 | learning_rate: 0.1 8 | n_iterations: 300 9 | gradient: 10 | estimator: pathwise 11 | n_samples: 1 12 | batch_size: 1 13 | layer_1: 14 | latent_distribution: gaussian 15 | size: 1 16 | layer_0: 17 | weight_distribution: point_mass 18 | data_size: 1 19 | """ 20 | data = np.array([[20.3], [-30.3], [15.]]) 21 | 22 | def test_posterior_predictive(sample: np.array, data: np.array) -> None: 23 | print('----') 24 | print('data:', data) 25 | print('posterior predictive sample:', sample) 26 | np.testing.assert_allclose(sample, np.expand_dims(data, 0), rtol=0.2) 27 | 28 | tester.test(config, data=data, test_fn=test_posterior_predictive) 29 | 30 | 31 | def test_gaussian_score(): 32 | config = """ 33 | learning_rate: 0.05 34 | n_iterations: 100 35 | gradient: 36 | estimator: score_function 37 | n_samples: 16 38 | batch_size: 3 39 | layer_1: 40 | latent_distribution: gaussian 41 | size: 7 42 | layer_0: 43 | weight_distribution: point_mass 44 | data_size: 1 45 | """ 46 | data = np.array([[20.3], [30.3], [15.]]) 47 | 48 | def test_posterior_predictive(sample: np.array, data: np.array) -> None: 49 | print('----') 50 | print('data:', data) 51 | print('posterior predictive sample:', sample) 52 | np.testing.assert_allclose(sample, np.expand_dims(data, 0), rtol=0.2) 53 | 54 | tester.test(config, data=data, test_fn=test_posterior_predictive) 55 | 56 | 57 | def test_gaussian_score_multivariate_data(): 58 | config = """ 59 | learning_rate: 0.1 60 | n_iterations: 100 61 | gradient: 62 | estimator: score_function 63 | n_samples: 16 64 | batch_size: 3 65 | layer_1: 66 | latent_distribution: gaussian 67 | size: 7 68 | layer_0: 69 | weight_distribution: point_mass 70 | data_size: 3 71 | """ 72 | data = np.array( 73 | [[20.3, -30.3, 15.3], [-30.3, -40.4, 23.5], [15., -20.3, 28.9]]) 74 | 75 | def test_posterior_predictive(sample: np.array, data: np.array) -> None: 76 | print('----') 77 | print('data:', data) 78 | print('posterior predictive sample:', sample) 79 | np.testing.assert_allclose(sample, np.expand_dims(data, 0), rtol=0.2) 80 | 81 | tester.test(config, data=data, test_fn=test_posterior_predictive) 82 | 83 | 84 | def test_poisson_multivariate_data(): 85 | config = """ 86 | learning_rate: 0.05 87 | n_iterations: 200 88 | gradient: 89 | estimator: score_function 90 | n_samples: 64 91 | batch_size: 3 92 | layer_1: 93 | latent_distribution: poisson 94 | size: 3 95 | layer_0: 96 | weight_distribution: point_mass 97 | data_size: 2 98 | """ 99 | data = np.array( 100 | [[-20.3, 30.3], [-30.3, 40.4], [-10.5, 20.3]]) 101 | 102 | def test_posterior_predictive(sample: np.array, data: np.array) -> None: 103 | print('----') 104 | print('data:', data) 105 | print('posterior predictive sample:', sample) 106 | np.testing.assert_allclose(sample, np.expand_dims(data, 0), rtol=0.5) 107 | 108 | tester.test(config, data=data, test_fn=test_posterior_predictive) 109 | -------------------------------------------------------------------------------- /tests/test_variational_expectation_maximization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tester 3 | 4 | 5 | def test_one_layer_gaussian(): 6 | """Test variational EM with pathwise gradients.""" 7 | config = """ 8 | learning_rate: 0.1 9 | n_iterations: 100 10 | gradient: 11 | estimator: pathwise 12 | n_samples: 1 13 | batch_size: 1 14 | 
layer_1: 15 | latent_distribution: gaussian 16 | p_z_variance: 1. 17 | size: 1 18 | layer_0: 19 | weight_distribution: point_mass 20 | p_w_variance: 1. 21 | data_size: 1 22 | """ 23 | data = np.array([[3.3]]) 24 | 25 | def test_posterior_predictive(sample: np.array) -> None: 26 | print('data:', data) 27 | print('posterior predictive sample:', sample) 28 | np.testing.assert_allclose(sample, data, rtol=0.3) 29 | 30 | tester.test(config, data=data, test_fn=test_posterior_predictive) 31 | 32 | 33 | def test_one_layer_gaussian_score(): 34 | """Test variational EM with score function gradients.""" 35 | config = """ 36 | learning_rate: 0.1 37 | n_iterations: 100 38 | print_every: 100 39 | gradient: 40 | estimator: score_function 41 | n_samples: 32 42 | batch_size: 1 43 | layer_1: 44 | latent_distribution: gaussian 45 | size: 1 46 | layer_0: 47 | weight_distribution: point_mass 48 | data_size: 1 49 | """ 50 | data = np.array([[30.3]]) 51 | 52 | def test_posterior_predictive(sample: np.array) -> None: 53 | print('data:', data) 54 | print('posterior predictive sample:', sample) 55 | np.testing.assert_allclose(sample, data, rtol=1e-1) 56 | 57 | tester.test(config, data=data, test_fn=test_posterior_predictive) 58 | 59 | 60 | def test_one_layer_poisson_score(): 61 | """Test variational EM with score function gradients and poisson latents.""" 62 | config = """ 63 | learning_rate: 0.1 64 | n_iterations: 100 65 | print_every: 100 66 | gradient: 67 | estimator: score_function 68 | n_samples: 32 69 | batch_size: 1 70 | layer_1: 71 | latent_distribution: poisson 72 | size: 1 73 | layer_0: 74 | weight_distribution: point_mass 75 | data_size: 1 76 | """ 77 | data = np.array([[30.3]]) 78 | 79 | def test_posterior_predictive(sample: np.array) -> None: 80 | print('data:', data) 81 | print('posterior predictive sample:', sample) 82 | np.testing.assert_allclose(sample, data, rtol=0.3) 83 | 84 | tester.test(config, data=data, test_fn=test_posterior_predictive) 85 | 86 | 87 | def test_two_layer_gaussian_score(): 88 | """Test variational EM with score function gradients and poisson latents.""" 89 | config = """ 90 | learning_rate: 0.1 91 | n_iterations: 100 92 | gradient: 93 | estimator: score_function 94 | n_samples: 32 95 | batch_size: 1 96 | layer_2: 97 | latent_distribution: gaussian 98 | size: 1 99 | layer_1: 100 | latent_distribution: gaussian 101 | weight_distribution: point_mass 102 | size: 1 103 | layer_0: 104 | weight_distribution: point_mass 105 | data_size: 1 106 | """ 107 | data = np.array([[8.5]]) 108 | 109 | def test_posterior_predictive(sample: np.array) -> None: 110 | print('data:', data) 111 | print('posterior predictive sample:', sample) 112 | np.testing.assert_allclose(sample, data, rtol=1e-1) 113 | 114 | tester.test(config, data=data, test_fn=test_posterior_predictive) 115 | -------------------------------------------------------------------------------- /tests/score_function_gaussian_factor_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | 5 | def normal_log_prob(x, mean, var): 6 | return -0.5 * np.log(2. * np.pi * var) - 0.5 * np.square(x - mean) / var 7 | 8 | 9 | def grad_mean_normal_log_prob(x, mean, var): 10 | return (x - mean) / var 11 | 12 | 13 | def grad_var_normal_log_prob(x, mean, var): 14 | return -0.5 / var + 0.5 * np.square(x - mean) / np.square(var) 15 | 16 | 17 | def softplus(x): 18 | return np.log(1. + np.exp(x)) 19 | 20 | 21 | def grad_softplus(x): 22 | return np.exp(x) / (1. 
+ np.exp(x)) 23 | 24 | 25 | def inv_softplus(x): 26 | return np.log(np.exp(x) - 1.) 27 | 28 | 29 | def grad_normal_log_prob(x, mean, var_arg): 30 | grad_var_arg = (grad_var_normal_log_prob(x, mean, softplus(var_arg)) 31 | * grad_softplus(var_arg)) 32 | grad_mean = grad_mean_normal_log_prob(x, mean, softplus(var_arg)) 33 | return grad_mean, grad_var_arg 34 | 35 | 36 | def normal_sample(mean, var_arg, n_samples): 37 | return mean + np.sqrt(softplus(var_arg)) * np.random.normal(size=n_samples) 38 | 39 | 40 | def bbvi(): 41 | p_z_mean = 0. 42 | p_z_var = 1. 43 | q_z_mean = 0. 44 | q_z_var_arg = inv_softplus(1.) 45 | q_w_mean = 0. 46 | q_w_var_arg = inv_softplus(1.) 47 | p_w_mean = 0. 48 | p_w_var = 1. 49 | p_x_var = 1. 50 | x = 30. 51 | n_samples = 32 52 | learning_rate = 1e-5 53 | 54 | for i in range(500000): 55 | t0 = time.time() 56 | z_sample = normal_sample(q_z_mean, q_z_var_arg, n_samples) 57 | log_q_z = normal_log_prob(z_sample, q_z_mean, softplus(q_z_var_arg)) 58 | log_p_z = normal_log_prob(z_sample, p_z_mean, p_z_var) 59 | w_sample = normal_sample(q_w_mean, q_w_var_arg, n_samples) 60 | log_q_w = normal_log_prob(w_sample, q_w_mean, softplus(q_w_var_arg)) 61 | log_p_w = normal_log_prob(w_sample, p_w_mean, p_w_var) 62 | p_x_mean = w_sample * z_sample 63 | log_p_x = normal_log_prob(x, p_x_mean, p_x_var) 64 | elbo = log_p_x + log_p_z - log_q_z + log_p_w - log_q_w 65 | 66 | # gradients 67 | score_q_z_mean, score_q_z_var_arg = grad_normal_log_prob( 68 | z_sample, q_z_mean, q_z_var_arg) 69 | score_q_w_mean, score_q_w_var_arg = grad_normal_log_prob( 70 | w_sample, q_w_mean, q_w_var_arg) 71 | 72 | if i % 1000 == 0: 73 | print('i: %d\t elbo: %.3f\tposterior predicitve mean, data: %.3f, %.3f' % 74 | (i, np.mean(elbo), np.mean(p_x_mean), x)) 75 | print('q_w_mean: %.3f\tq_z_mean: %.3f' % (q_w_mean, q_z_mean)) 76 | print('q_w_var: %.3f\tq_z_var: %.3f' % 77 | (softplus(q_w_var_arg), softplus(q_z_var_arg))) 78 | 79 | # updates 80 | q_z_mean += learning_rate * bbvi_gradient(score_q_z_mean, elbo) 81 | q_z_var_arg += learning_rate * bbvi_gradient(score_q_z_var_arg, elbo) 82 | q_w_mean += learning_rate * bbvi_gradient(score_q_w_mean, elbo) 83 | q_w_var_arg += learning_rate * bbvi_gradient(score_q_w_var_arg, elbo) 84 | t1 = time.time() 85 | # print('time: ', t1 - t0) 86 | 87 | 88 | def bbvi_gradient(score_function, elbo): 89 | grad_elbo = score_function * elbo 90 | cov = leave_one_out_mean(score_function * grad_elbo) 91 | var = leave_one_out_mean(score_function * score_function) 92 | a = cov / var 93 | return np.mean(grad_elbo - score_function * a) 94 | 95 | 96 | def leave_one_out_mean(a): 97 | return (np.sum(a, 0, keepdims=True) - a) / (a.shape[0] - 1.) 
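# Estimating the control variate coefficient from the same sample it is applied
# to would bias the gradient, so leave_one_out_mean averages over the other
# n - 1 samples for each index i, e.g.
#   leave_one_out_mean(np.array([1., 2., 6.]))  ->  [4.0, 3.5, 1.5]
# which keeps the estimator in bbvi_gradient unbiased.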
98 | 99 | 100 | if __name__ == '__main__': 101 | bbvi() 102 | -------------------------------------------------------------------------------- /distributions/bernoulli.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | import mxnet as mx 4 | 5 | from mxnet import nd 6 | from mxnet import gluon 7 | from typing import Union 8 | from .distributions import BaseDistribution 9 | from common import util 10 | 11 | ZERO = nd.array([0.]) 12 | ONE = nd.array([1.]) 13 | 14 | 15 | class BaseBernoulli(BaseDistribution, metaclass=abc.ABCMeta): 16 | 17 | @property 18 | def is_reparam(self): 19 | return False 20 | 21 | mean = None 22 | logits = None 23 | 24 | def sample(self, n_samples: int = 1) -> nd.NDArray: 25 | mean = self.get_param_not_repeated('mean') 26 | if n_samples == 1: 27 | return nd.sample_uniform(ZERO, ONE, shape=mean.shape) < mean 28 | else: 29 | shape = (n_samples,) + mean.shape 30 | return nd.sample_uniform(ZERO, ONE, shape=shape)[0, :] < mean 31 | 32 | def log_prob(self, x: nd.NDArray) -> nd.NDArray: 33 | logits = self.get_param_maybe_repeated('logits') 34 | if x.ndim > logits.ndim: 35 | logits = nd.expand_dims(logits, 0) 36 | return x * logits - util.softplus(logits) 37 | 38 | 39 | class Bernoulli(BaseBernoulli): 40 | def __init__(self, logits: nd.NDArray) -> None: 41 | super(Bernoulli, self).__init__() 42 | self.logits = logits 43 | 44 | @property 45 | def mean(self): 46 | return util.sigmoid(self.logits) 47 | 48 | 49 | class FastBernoulli(BaseBernoulli): 50 | """Fast parameterization of Bernoulli as in the survival filter paper. 51 | 52 | Complexity O(CK) + O(CK) reduced to O(CK) + O(sK) where s in nonzeros. 53 | 54 | References: 55 | http://auai.org/uai2015/proceedings/papers/246.pdf 56 | """ 57 | 58 | def __init__(self, 59 | positive_latent: nd.NDArray, 60 | weight: nd.NDArray, 61 | bias: nd.NDArray) -> None: 62 | """Number of classes is C; latent dimension K. 63 | 64 | Args: 65 | positive_latent: shape [batch_size, K] positive latent variable 66 | weight: shape [K, C] real-valued weight 67 | """ 68 | super(FastBernoulli, self).__init__() 69 | # mean_arg is of shape [batch_size, C] 70 | self._positive_latent = positive_latent 71 | self._weight = weight 72 | self._bias = bias 73 | self.logits = None 74 | 75 | @property 76 | def mean(self): 77 | arg = nd.dot(self._positive_latent, nd.exp( 78 | self._weight)) + nd.exp(self._bias) 79 | return 1. - nd.exp(-arg) 80 | 81 | def log_prob(self, nonzero_index): 82 | raise NotImplementedError("Not implemented!") 83 | 84 | def log_prob_sum(self, nonzero_index: nd.NDArray) -> nd.NDArray: 85 | """Returns log prob. Argument is batch of indices of nonzero classes. 86 | log p(x) = term_1 + term_2 87 | term_1 = sum_c log p(x_c = 0) 88 | term_2 = sum_{c: x_c = 1} log p(x_c = 1) - log p(x_c = 0) 89 | term_1 takes O(CK) to calculate. 90 | term_2 takes O(CK) + O(sK) with s being the number of nonzero entries in x 91 | """ 92 | mean_arg = -(nd.dot(self._positive_latent, nd.exp(self._weight)) 93 | + nd.exp(self._bias)) 94 | assert mean_arg.shape[1] == 1, "Fast Bernoulli only supports batch size 1!" 
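    # mean_arg has shape [n_samples, batch_size, n_classes]; the batch size must
    # be 1 so the batch axis can be dropped before the per-class lookup of the
    # nonzero entries below.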
95 | mean_arg = mean_arg[:, 0, :] 96 | term_1 = nd.sum(mean_arg, -1) 97 | n_factors, n_classes = self._weight.shape 98 | # weight_nonzero = nd.Embedding( 99 | # nonzero_index, self._weight.T, n_classes, n_factors).T 100 | # nonzero_arg = -nd.dot(self._positive_latent, nd.exp(weight_nonzero)) 101 | # raise NotImplementedError('need to add bias lookup!') 102 | batch_size = mean_arg.shape[0] 103 | nonzero_arg = nd.Embedding( 104 | nonzero_index, mean_arg.T, n_classes, batch_size).T 105 | term_2 = nd.sum(nd.log(1. - nd.exp(nonzero_arg)) - nonzero_arg, -1) 106 | res = term_1 + term_2 107 | return nd.expand_dims(res, 1) 108 | -------------------------------------------------------------------------------- /common/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import mxnet as mx 4 | import numpy as np 5 | import collections 6 | 7 | from mxnet import nd 8 | 9 | 10 | def log_to_file(filename): 11 | logging.basicConfig(level=logging.INFO, 12 | format='%(asctime)s %(name)-4s %(levelname)-4s %(message)s', 13 | datefmt='%m-%d %H:%M', 14 | filename=filename, 15 | filemode='a') 16 | console = logging.StreamHandler() 17 | console.setLevel(logging.INFO) 18 | logging.getLogger('').addHandler(console) 19 | 20 | 21 | def flatten(l): return [item for sublist in l for item in sublist] 22 | 23 | 24 | def softplus(x): 25 | return nd.Activation(x, act_type='softrelu') 26 | 27 | 28 | def np_softplus(x): 29 | return np.log(1. + np.exp(x)) 30 | 31 | 32 | class Softplus(object): 33 | def __init__(self): 34 | pass 35 | 36 | def __call__(self, x: nd.NDArray) -> nd.NDArray: 37 | return nd.Activation(x, act_type='softrelu') 38 | 39 | def backward(self, x): 40 | # d/dx log(1 + exp(x)) = exp(x) / (1 + exp(x)) = 1. / (1. + exp(-x)) 41 | return nd.sigmoid(x) 42 | 43 | 44 | def sigmoid(x): 45 | return nd.Activation(x, act_type='sigmoid') 46 | 47 | 48 | def np_inverse_softplus(x): 49 | return np.log(np.exp(x) - 1.) 
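# np_softplus and np_inverse_softplus above exponentiate their argument
# directly, which overflows for large inputs. A numerically stable variant, as
# a sketch (the helpers below are hypothetical and not referenced elsewhere in
# this codebase):


def np_softplus_stable(x):
  # log(1 + exp(x)) == max(x, 0) + log1p(exp(-|x|))
  return np.maximum(x, 0.) + np.log1p(np.exp(-np.abs(x)))


def np_inverse_softplus_stable(y):
  # log(exp(y) - 1) == y + log1p(-exp(-y)), valid for y > 0
  return y + np.log1p(-np.exp(-y))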
50 | 51 | 52 | def latest_checkpoint(directory): 53 | files = [f for f in os.listdir(directory) if 'params' in f] 54 | if len(files) > 0 and any('params' in f for f in files): 55 | l = sorted((int(f.split('-')[-1]), i) for i, f in enumerate(files)) 56 | return os.path.join(directory, files[l[-1][-1]]), l[-1][0] 57 | else: 58 | return None, None 59 | 60 | 61 | def repeat_emb(param, emb): 62 | """Maybe repeat an embedding.""" 63 | res = nd.expand_dims(emb, 0) 64 | param.repeated = nd.repeat(res, repeats=param.n_repeats, axis=0) 65 | param.repeated.attach_grad() 66 | return param.repeated 67 | 68 | 69 | def pathwise_grad_variance_callback(my_model, data_batch): 70 | """Get pathwise gradient estimator variance.""" 71 | param_grads = collections.defaultdict(lambda: []) # type: ignore 72 | params = my_model.collect_params() 73 | n_samples_stats = 10 74 | for i in range(n_samples_stats): 75 | with mx.autograd.record(): 76 | log_q_sum, elbo, sample = my_model(data_batch) 77 | my_model.compute_gradients(elbo, data_batch, log_q_sum) 78 | for name, param in params.items(): 79 | if param.grad_req != 'null': 80 | param_grads[name].append(param.grad().asnumpy()) 81 | 82 | 83 | def callback_elbo_sample(my_model, data_batch): 84 | """Get a reduced-variance estimate of the elbo and sample.""" 85 | n_samples_stats = 10 86 | _, elbo, sample = my_model(data_batch) 87 | for _ in range(n_samples_stats): 88 | tmp_sample = nd.zeros_like(sample) 89 | tmp_elbo = nd.zeros_like(elbo) 90 | for _ in range(n_samples_stats): 91 | _, elbo, sample = my_model(data_batch) 92 | tmp_sample += sample 93 | tmp_elbo += elbo 94 | tmp_sample /= n_samples_stats 95 | tmp_elbo /= n_samples_stats 96 | tmp_sample = np.mean(tmp_sample.asnumpy(), 0) 97 | tmp_elbo = np.mean(tmp_elbo.asnumpy()) 98 | return tmp_elbo, tmp_sample 99 | 100 | 101 | def score_grad_variance_callback(my_model): 102 | """Get score function gradient variance.""" 103 | params = my_model.collect_params() 104 | param_grads = collections.defaultdict(lambda: []) # type: ignore 105 | for name, param in params.items(): 106 | if param.grad_req != 'null': 107 | grads = np.stack(param_grads[name]) 108 | param.grad_variance = np.mean(np.var(grads, axis=0)) 109 | param.grad_norm = np.mean(np.linalg.norm(grads, axis=-1)) 110 | for block in my_model.sequential._children: 111 | print(block.name, ':') 112 | print([(name, p.data().asnumpy().tolist()) 113 | for name, p in filter( 114 | lambda x: 'weight' in x[0] or 'bias' in x[0], 115 | block.collect_params().items())]) 116 | for child_block in block._children: 117 | print(child_block.name, ':') 118 | print('mean:', child_block.get_param_not_repeated('mean').asnumpy()) 119 | if hasattr(child_block, 'variance'): 120 | print('variance: ', child_block.get_param_not_repeated( 121 | 'variance').asnumpy()) 122 | -------------------------------------------------------------------------------- /distributions/gaussian.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import abc 4 | 5 | from .distributions import BaseDistribution 6 | from typing import Union 7 | from typing import Tuple 8 | from mxnet import gluon 9 | from mxnet import nd 10 | from common import util 11 | 12 | 13 | class BaseGaussian(BaseDistribution, metaclass=abc.ABCMeta): 14 | 15 | @property 16 | def is_reparam(self) -> bool: 17 | return True 18 | 19 | mean = None 20 | variance = None 21 | 22 | def sample(self, n_samples: int = 1) -> nd.NDArray: 23 | # reparameterization / pathwise trick for 
backprop 24 | mean = self.get_param_not_repeated('mean') 25 | variance = self.get_param_not_repeated('variance') 26 | shape = (n_samples,) + mean.shape 27 | return mean + nd.sqrt(variance) * nd.random_normal(shape=shape) 28 | 29 | def log_prob(self, x: nd.NDArray) -> nd.NDArray: 30 | mean = self.get_param_maybe_repeated('mean') 31 | variance = self.get_param_maybe_repeated('variance') 32 | if x.ndim > mean.ndim: 33 | mean = nd.expand_dims(mean, 0) 34 | variance = nd.expand_dims(variance, 0) 35 | diff = x - mean 36 | self._saved_for_backward = [diff] 37 | return (-0.5 * nd.log(2. * np.pi * variance) 38 | - nd.square(diff) / 2. / variance) 39 | 40 | 41 | class Gaussian(BaseGaussian): 42 | def __init__(self, 43 | mean: nd.NDArray, 44 | variance: nd.NDArray) -> None: 45 | super(Gaussian, self).__init__() 46 | self.mean = mean 47 | self.variance = variance 48 | 49 | 50 | class PriorGaussian(BaseGaussian): 51 | def __init__(self, 52 | name: str, 53 | shape: Union[int, tuple], 54 | mean: float = 0., 55 | variance: float = 1.) -> None: 56 | super(PriorGaussian, self).__init__() 57 | with self.name_scope(): 58 | self.mean = self.params.get(name + '_mean', 59 | init=mx.init.Constant(mean), 60 | shape=shape, 61 | grad_req='null') 62 | self.variance = self.params.get(name + '_variance', 63 | init=mx.init.Constant(variance), 64 | shape=shape, 65 | grad_req='null') 66 | 67 | def __call__(self, 68 | z_above: nd.NDArray, 69 | weight: nd.NDArray, 70 | bias: nd.NDArray) -> Gaussian: 71 | """Call the prior layer at a lower layer of a DEF.""" 72 | mean = nd.dot(z_above, weight) + bias 73 | variance = self.variance 74 | return Gaussian(mean, variance) 75 | 76 | 77 | class VariationalGaussian(BaseGaussian): 78 | def __init__(self, 79 | name: str, 80 | shape: tuple, 81 | mean: float = 0., 82 | variance: float = 1., 83 | grad_req: str = 'write') -> None: 84 | super(VariationalGaussian, self).__init__() 85 | mean_init = mx.init.Constant(mean) 86 | with self.name_scope(): 87 | self.mean = self.params.get(name + '_mean', init=mean_init, shape=shape, 88 | grad_req=grad_req) 89 | variance_arg_init = mx.init.Constant(util.np_inverse_softplus(variance)) 90 | self.variance = self.params.get(name + '_variance_arg', 91 | init=variance_arg_init, shape=shape, 92 | grad_req=grad_req) 93 | self.variance.link_function = util.softplus 94 | 95 | 96 | class VariationalLookupGaussian(BaseGaussian): 97 | def __init__(self, 98 | name: str, 99 | shape: tuple, 100 | mean: float = 0., 101 | variance: float = 1., 102 | grad_req: str = 'write') -> None: 103 | """Mean-field factorized variational Gaussian. 
Per-datapoint parameters.""" 104 | super(VariationalLookupGaussian, self).__init__() 105 | with self.name_scope(): 106 | mean_init = mx.init.Constant(mean) 107 | self._mean_emb = self.params.get( 108 | name + '_mean_emb', init=mean_init, shape=shape, grad_req=grad_req) 109 | variance_arg_init = mx.init.Constant(util.np_inverse_softplus(variance)) 110 | self._variance_arg_emb = self.params.get( 111 | name + '_variance_arg_emb', init=variance_arg_init, shape=shape, 112 | grad_req=grad_req) 113 | self.link_function = util.Softplus() 114 | 115 | def lookup(self, data_index: nd.NDArray): 116 | """Return the distribution for the data batch.""" 117 | shape = self._mean_emb.shape 118 | self.mean = nd.Embedding(data_index, self._mean_emb.data(), *shape) 119 | variance_arg = nd.Embedding( 120 | data_index, self._variance_arg_emb.data(), *shape) 121 | self.variance = self.link_function(variance_arg) 122 | if hasattr(self._mean_emb, 'n_repeats'): 123 | self.mean_repeated = util.repeat_emb(self._mean_emb, self.mean) 124 | if hasattr(self._variance_arg_emb, 'n_repeats'): 125 | self.variance_repeated = self.link_function( 126 | util.repeat_emb(self._variance_arg_emb, variance_arg)) 127 | return self 128 | -------------------------------------------------------------------------------- /distributions/poisson.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | import mxnet as mx 4 | 5 | from mxnet import nd 6 | from mxnet import gluon 7 | from typing import Union 8 | from typing import Callable 9 | from .distributions import BaseDistribution 10 | from common import util 11 | 12 | 13 | class BasePoisson(BaseDistribution, metaclass=abc.ABCMeta): 14 | 15 | @property 16 | def is_reparam(self): 17 | return False 18 | 19 | mean = None 20 | 21 | def sample(self, n_samples: int = 1) -> nd.NDArray: 22 | mean = self.get_param_not_repeated('mean') 23 | if n_samples == 1: 24 | res = nd.sample_poisson(mean) 25 | else: 26 | res = nd.sample_poisson(mean, n_samples) 27 | res = nd.transpose(res) 28 | if res.ndim == 3: 29 | return nd.swapaxes(res, 1, 2) 30 | elif res.ndim == 2: 31 | return res 32 | else: 33 | raise ValueError('Ambiguous sample shape.') 34 | 35 | def log_prob(self, x: nd.NDArray) -> nd.NDArray: 36 | mean = self.get_param_maybe_repeated('mean') 37 | if x.ndim > mean.ndim: 38 | mean = nd.expand_dims(mean, 0) 39 | np_x = x.asnumpy().astype(np.int32).astype(np.float32) 40 | np.testing.assert_almost_equal(x.asnumpy(), np_x) 41 | return x * nd.log(mean) - mean - nd.gammaln(x + 1.) 42 | 43 | 44 | class Poisson(BasePoisson): 45 | def __init__(self, mean: nd.NDArray) -> None: 46 | super(Poisson, self).__init__() 47 | self._mean = mean 48 | 49 | @property 50 | def mean(self): 51 | return self._mean 52 | 53 | 54 | class PriorPoisson(BasePoisson): 55 | def __init__(self, 56 | name: str, 57 | shape: Union[int, tuple], 58 | mean: float = 1.) -> None: 59 | super(PriorPoisson, self).__init__() 60 | with self.name_scope(): 61 | assert mean > 0. 62 | self.mean = self.params.get(name, 63 | shape=shape, 64 | init=mx.init.Constant(mean), 65 | grad_req='null') 66 | 67 | def __call__(self, 68 | z_above: nd.NDArray, 69 | weight: nd.NDArray, 70 | bias: nd.NDArray) -> Poisson: 71 | """Call the prior layer at a lower layer of a DEF.""" 72 | mean = util.softplus(nd.dot(z_above, weight) + bias) 73 | return Poisson(mean) 74 | 75 | 76 | class VariationalPoisson(BasePoisson): 77 | def __init__(self, name: str, shape: tuple, mean: float = 1.) 
-> None: 78 | super(VariationalPoisson, self).__init__() 79 | with self.name_scope(): 80 | mean_arg_init = mx.init.Constant(util.np_inverse_softplus(mean)) 81 | self.mean = self.params.get( 82 | name + '_mean_arg', init=mean_arg_init, shape=shape) 83 | self.mean.link_function = util.softplus 84 | 85 | 86 | class VariationalLookupPoisson(BasePoisson): 87 | def __init__(self, 88 | name: str, 89 | shape: tuple, 90 | mean: float = 1., 91 | init_pca: np.array = None) -> None: 92 | """Mean-field factorized variational Poisson. Per-datapoint parameters.""" 93 | super(VariationalLookupPoisson, self).__init__() 94 | with self.name_scope(): 95 | assert mean > 0. 96 | if init_pca is not None: 97 | # assert np.all(init_pca > 0) 98 | pca_mean = np.mean(util.np_softplus(init_pca), -1) 99 | correction = util.np_inverse_softplus(mean - pca_mean) 100 | init_pca_arg = init_pca + np.expand_dims(correction, -1) 101 | mean_arg_init = mx.init.Constant(init_pca_arg) 102 | else: 103 | mean_arg_init = UniformInit(mean=mean, 104 | scale=0.07, 105 | inverse_transform=util.np_inverse_softplus) 106 | self._mean_arg_emb = self.params.get( 107 | name + '_mean_arg_emb', init=mean_arg_init, shape=shape) 108 | self.link_function = util.softplus 109 | 110 | def lookup(self, labels: nd.NDArray, repeat: bool = True): 111 | """Return the distribution for the data batch.""" 112 | shape = self._mean_arg_emb.shape 113 | mean_arg_emb = nd.Embedding(labels, self._mean_arg_emb.data(), *shape) 114 | self.mean = nd.maximum(5e-3, self.link_function(mean_arg_emb)) 115 | if hasattr(self._mean_arg_emb, 'n_repeats') and repeat: 116 | self.mean_repeated = self.link_function( 117 | util.repeat_emb(self._mean_arg_emb, mean_arg_emb)) 118 | return self 119 | 120 | 121 | @mx.init.register 122 | class UniformInit(mx.init.Initializer): 123 | def __init__(self, 124 | mean: float = 0., 125 | scale: float = 0.07, 126 | inverse_transform: Callable = None) -> None: 127 | super(UniformInit, self).__init__() 128 | if inverse_transform is not None: 129 | self.low = inverse_transform(mean - scale) 130 | self.high = inverse_transform(mean + scale) 131 | else: 132 | self.low = mean - scale 133 | self.high = mean + scale 134 | 135 | def _init_weight(self, _, arr): 136 | arr[:] = nd.random.uniform(low=self.low, high=self.high, shape=arr.shape) 137 | 138 | def _init_bias(self, _, arr): 139 | arr[:] = nd.zeros(shape=arr.shape) 140 | -------------------------------------------------------------------------------- /tests/test_multivariate_def.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tester 3 | 4 | 5 | def test_one_layer_gaussian(): 6 | """Test variational EM with pathwise gradients.""" 7 | config = """ 8 | learning_rate: 0.1 9 | n_iterations: 100 10 | gradient: 11 | estimator: pathwise 12 | n_samples: 7 13 | batch_size: 1 14 | layer_1: 15 | latent_distribution: gaussian 16 | p_z_variance: 1. 17 | size: 3 18 | layer_0: 19 | weight_distribution: point_mass 20 | p_w_variance: 1.
21 | data_size: 1 22 | """ 23 | data = np.array([[3.3]]) 24 | 25 | def test_posterior_predictive(sample: np.array) -> None: 26 | print('data:', data) 27 | print('posterior predictive sample:', sample) 28 | np.testing.assert_allclose(sample, data, rtol=1e-1) 29 | 30 | tester.test( 31 | config, data=data, test_fn=test_posterior_predictive) 32 | 33 | 34 | def test_one_layer_gaussian_score(): 35 | """Test variational EM with score function gradients.""" 36 | config = """ 37 | learning_rate: 0.1 38 | n_iterations: 100 39 | print_every: 100 40 | gradient: 41 | estimator: score_function 42 | n_samples: 32 43 | batch_size: 1 44 | layer_1: 45 | latent_distribution: gaussian 46 | size: 7 47 | layer_0: 48 | weight_distribution: point_mass 49 | data_size: 1 50 | """ 51 | data = np.array([[30.3]]) 52 | 53 | def test_posterior_predictive(sample: np.array) -> None: 54 | print('data:', data) 55 | print('posterior predictive sample:', sample) 56 | np.testing.assert_allclose(sample, data, rtol=1e-1) 57 | 58 | tester.test( 59 | config, data=data, test_fn=test_posterior_predictive) 60 | 61 | 62 | def test_multivariate_data(): 63 | """Test variational EM with score function gradients.""" 64 | config = """ 65 | learning_rate: 0.1 66 | n_iterations: 100 67 | print_every: 100 68 | gradient: 69 | estimator: score_function 70 | n_samples: 16 71 | batch_size: 1 72 | layer_1: 73 | latent_distribution: gaussian 74 | size: 1 75 | layer_0: 76 | weight_distribution: point_mass 77 | data_size: 3 78 | """ 79 | # shape [n_data, data_size] 80 | data = np.array([[30.3, -10., 5.]]) 81 | 82 | def test_posterior_predictive(sample: np.array) -> None: 83 | print('data:', data) 84 | print('posterior predictive sample:', sample) 85 | np.testing.assert_allclose(sample, data, rtol=1e-1) 86 | 87 | tester.test( 88 | config, data=data, test_fn=test_posterior_predictive) 89 | 90 | 91 | def test_latent_poisson_layer(): 92 | """Test matching q(z) to p(z) where p is Poisson.""" 93 | mean = 5. 
94 | 95 | config = """ 96 | n_iterations: 100 97 | learning_rate: 0.1 98 | gradient: 99 | estimator: score_function 100 | n_samples: 16 101 | batch_size: 1 102 | layer_1: 103 | latent_distribution: poisson 104 | size: 3 105 | p_z_mean: {} 106 | """.format(mean) 107 | 108 | def test_posterior_predictive(sample: np.array) -> None: 109 | print('posterior predictive sample:', sample) 110 | print('prior mean:', mean) 111 | np.testing.assert_allclose(sample, mean, rtol=1e-1) 112 | 113 | tester.test(config, data=np.array([np.nan]), 114 | test_fn=test_posterior_predictive) 115 | 116 | 117 | def test_one_layer_poisson(): 118 | """Test variational EM with score function gradients and poisson latents.""" 119 | config = """ 120 | learning_rate: 0.1 121 | n_iterations: 100 122 | print_every: 100 123 | gradient: 124 | estimator: score_function 125 | n_samples: 32 126 | batch_size: 1 127 | layer_1: 128 | latent_distribution: poisson 129 | size: 3 130 | layer_0: 131 | weight_distribution: point_mass 132 | data_size: 1 133 | """ 134 | data = np.array([[30.3]]) 135 | 136 | def test_posterior_predictive(sample: np.array) -> None: 137 | print('data:', data) 138 | print('posterior predictive sample:', sample) 139 | np.testing.assert_allclose(sample, data, rtol=0.3) 140 | 141 | tester.test( 142 | config, data=data, test_fn=test_posterior_predictive) 143 | 144 | 145 | def test_two_layer_gaussian_score(): 146 | """Test variational EM with score function gradients and gaussian latents.""" 147 | config = """ 148 | learning_rate: 0.1 149 | n_iterations: 100 150 | gradient: 151 | estimator: score_function 152 | n_samples: 32 153 | batch_size: 1 154 | layer_2: 155 | latent_distribution: gaussian 156 | size: 2 157 | layer_1: 158 | latent_distribution: gaussian 159 | weight_distribution: point_mass 160 | size: 3 161 | layer_0: 162 | weight_distribution: point_mass 163 | data_size: 1 164 | """ 165 | data = np.array([[8.5]]) 166 | 167 | def test_posterior_predictive(sample: np.array) -> None: 168 | print('data:', data) 169 | print('posterior predictive sample:', sample) 170 | np.testing.assert_allclose(sample, data, rtol=1e-1) 171 | 172 | tester.test( 173 | config, data=data, test_fn=test_posterior_predictive) 174 | 175 | 176 | def test_two_layer_poisson_score(): 177 | """Test variational EM with score function gradients and gaussian latents.""" 178 | config = """ 179 | learning_rate: 0.1 180 | n_iterations: 100 181 | gradient: 182 | estimator: score_function 183 | n_samples: 32 184 | batch_size: 1 185 | layer_2: 186 | latent_distribution: poisson 187 | size: 2 188 | layer_1: 189 | latent_distribution: poisson 190 | weight_distribution: point_mass 191 | size: 3 192 | layer_0: 193 | weight_distribution: point_mass 194 | data_size: 1 195 | """ 196 | data = np.array([[-3.2]]) 197 | 198 | def test_posterior_predictive(sample: np.array) -> None: 199 | print('data:', data) 200 | print('posterior predictive sample:', sample) 201 | np.testing.assert_allclose(sample, data, rtol=0.3) 202 | 203 | tester.test( 204 | config, data=data, test_fn=test_posterior_predictive) 205 | 206 | 207 | if __name__ == '__main__': 208 | test_latent_poisson_layer() 209 | -------------------------------------------------------------------------------- /deep_exp_fam/layers.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import distributions 3 | 4 | from mxnet import nd 5 | from mxnet import gluon 6 | from mxnet import autograd 7 | from typing import Tuple 8 | from typing import Callable 9 | 
from common import util 10 | 11 | 12 | def _filter_kwargs(prefix, kwargs): 13 | """Return new kwargs matching a prefix, or empty dict if no matches.""" 14 | tmp_kwargs = {} 15 | for kwarg, value in kwargs.items(): 16 | if prefix in kwarg: 17 | kwarg = kwarg.lstrip(prefix) 18 | kwarg = kwarg.lstrip('_') 19 | tmp_kwargs[kwarg] = value 20 | return tmp_kwargs 21 | 22 | 23 | def _build_distribution( 24 | name: str, 25 | distribution_type: str, 26 | shape: tuple, 27 | **kwargs) -> distributions.BaseDistribution: 28 | kwargs = _filter_kwargs(name, kwargs) 29 | if name.startswith('p_'): 30 | attr = 'Prior' 31 | elif name.startswith('q_'): 32 | attr = 'Variational' 33 | if attr == 'Variational' and name.endswith('_z'): 34 | # assume mean-field variational inference with per-datapoint parameters 35 | attr += 'Lookup' 36 | attr += distribution_type.capitalize() 37 | dist_class = getattr(distributions, attr) 38 | return dist_class(name, shape, **kwargs) 39 | 40 | 41 | class LatentLayer(gluon.Block): 42 | def __init__(self, 43 | latent_distribution: str, 44 | n_data: int, 45 | size: int, 46 | size_above: int, 47 | gradient_config: dict, 48 | weight_distribution: str = None, 49 | **kwargs) -> None: 50 | super(LatentLayer, self).__init__() 51 | self.gradient_config = gradient_config 52 | with self.name_scope(): 53 | self.size_above = size_above 54 | self.p_z = _build_distribution( 55 | 'p_z', latent_distribution, (1, size), **kwargs) 56 | self.q_z = _build_distribution( 57 | 'q_z', latent_distribution, (n_data, size), **kwargs) 58 | if size_above is not None: 59 | weight_shape = (size_above, size) 60 | if weight_distribution == 'point_mass': 61 | self.params.get('point_mass_weight', shape=weight_shape) 62 | self.params.get('point_mass_bias', shape=( 63 | size,), init=mx.init.Zero()) 64 | else: 65 | raise NotImplementedError( 66 | 'Need to implement non point-mass weight distribution!') 67 | 68 | def forward(self, 69 | data: tuple, 70 | log_q_sum: nd.NDArray, 71 | elbo_above: nd.NDArray, 72 | z_above: nd.NDArray, 73 | ) -> Tuple[nd.NDArray, nd.NDArray, nd.NDArray]: 74 | n_samples = self.gradient_config['n_samples'] 75 | q_z = self.q_z.lookup(data[1]) 76 | z_sample = q_z.sample(n_samples) 77 | if self.gradient_config['estimator'] == 'score_function': 78 | # important -- do not differentiate through latent sample 79 | z_sample = nd.stop_gradient(z_sample) 80 | log_q_z = nd.sum(q_z.log_prob(z_sample), -1) 81 | elbo = elbo_above if elbo_above is not None else 0. 82 | log_q_sum = log_q_sum if log_q_sum is not None else 0. 
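    # Running sums across layers: log_q_sum accumulates log q(z_l) for the
    # score-function estimator in compute_gradients(), and elbo accumulates the
    # layer-wise terms log p(z_l | layer above) - log q(z_l); both keep the
    # leading n_samples dimension (asserted below).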
83 | log_q_sum = log_q_sum + log_q_z 84 | if self.size_above is None: 85 | # first (top) layer, without weights 86 | log_p_z = nd.sum(self.p_z.log_prob(z_sample), -1) 87 | elbo = elbo + log_p_z - log_q_z 88 | elif self.size_above is not None: 89 | w = self.params.get('point_mass_weight').data() 90 | b = self.params.get('point_mass_bias').data() 91 | log_p_z = nd.sum(self.p_z(z_above, w, b).log_prob(z_sample), -1) 92 | elbo = elbo + log_p_z - log_q_z 93 | assert log_q_sum.shape[0] == n_samples 94 | assert elbo.shape[0] == n_samples 95 | return log_q_sum, elbo, z_sample 96 | 97 | 98 | class ObservationLayer(gluon.Block): 99 | def __init__(self, 100 | weight_distribution: str, 101 | size_above: int, 102 | data_size: int, 103 | gradient_config: dict, 104 | data_distribution: str = 'gaussian', 105 | **kwargs) -> None: 106 | super(ObservationLayer, self).__init__() 107 | self.gradient_config = gradient_config 108 | self.data_distribution = data_distribution 109 | with self.name_scope(): 110 | weight_shape = (size_above, data_size) 111 | if weight_distribution == 'point_mass': 112 | # w_init = UniformInit(mean=-4.) 113 | self.params.get('point_mass_weight', shape=weight_shape) 114 | self.params.get( 115 | 'point_mass_bias', shape=(data_size,), init=mx.init.Zero()) 116 | else: 117 | raise NotImplementedError( 118 | 'Weights and biases other than point mass not implemented!') 119 | 120 | def p_x_fn(self, 121 | z_above: nd.NDArray, 122 | weight: nd.NDArray, 123 | bias: nd.NDArray = None) -> distributions.BaseDistribution: 124 | # z_above: [n_samples, batch_size, size_above] 125 | # weight: [size_above, data_size] 126 | if self.data_distribution == 'gaussian': 127 | params = nd.dot(z_above, weight) + bias 128 | variance = nd.ones_like(params) 129 | return distributions.Gaussian(params, variance) 130 | elif self.data_distribution == 'bernoulli': 131 | params = nd.dot(z_above, weight) + bias 132 | return distributions.Bernoulli(logits=params) 133 | elif self.data_distribution == 'poisson': 134 | # minimum intercept is 0.01 135 | return distributions.Poisson( 136 | 0.01 + nd.dot(z_above, util.softplus(weight))) 137 | else: 138 | raise ValueError( 139 | 'Incompatible data distribution: %s' % self.data_distribution) 140 | 141 | def forward(self, 142 | data: tuple, 143 | log_q_sum: nd.NDArray, 144 | elbo_above: nd.NDArray, 145 | z_above: nd.NDArray) -> Tuple[nd.NDArray, nd.NDArray, nd.NDArray]: 146 | n_samples = self.gradient_config['n_samples'] 147 | w = self.params.get('point_mass_weight').data() 148 | b = self.params.get('point_mass_bias').data() 149 | p_x = self.p_x_fn(z_above, w, b) 150 | log_p_x = nd.sum(p_x.log_prob(data[0]), -1) 151 | elbo = elbo_above + log_p_x 152 | return log_q_sum, elbo, p_x.mean 153 | -------------------------------------------------------------------------------- /common/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mxnet as mx 3 | import numpy as np 4 | import typing 5 | import gensim 6 | import scipy.sparse 7 | import sklearn.decomposition 8 | import sklearn.metrics.pairwise 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | from mxnet import nd 13 | from mxnet import gluon 14 | 15 | 16 | def print_top_words(weight: gluon.Parameter, 17 | id2word: dict, 18 | top: int = 10) -> None: 19 | n_factors, vocab_size = weight.shape 20 | weight = weight.data().asnumpy() 21 | for factor_idx in range(n_factors): 22 | top_word_indices = np.argsort(weight[factor_idx])[::-1][0:top] 23 | 
logger.info('----------') 24 | logger.info('factor %d:' % factor_idx) 25 | for word_idx in top_word_indices: 26 | logger.info('%.3e\t%s' % 27 | (weight[factor_idx, word_idx], id2word[word_idx])) 28 | 29 | 30 | def tokenize_text(fname: str, 31 | vocab_size: int, 32 | invalid_label: int = -1, 33 | start_label: int = 0) -> typing.Tuple[list, dict]: 34 | """Get tokenized sentences and vocab.""" 35 | if not os.path.isfile(fname): 36 | raise IOError('Data file is not a file! Got: %s' % fname) 37 | lines = open(fname).readlines() 38 | lines = [line.rstrip('\n') for line in lines] 39 | # lines = [filter(None, i.split(' ')) for i in lines] 40 | lines = [i.split(' ') for i in lines] 41 | vocab = gensim.corpora.dictionary.Dictionary(lines) 42 | vocab.filter_extremes(no_below=0, no_above=1, keep_n=vocab_size) 43 | vocab = {v: k for k, v in vocab.items()} 44 | lines = [[w for w in sent if w in vocab] for sent in lines] 45 | sentences, vocab = mx.rnn.encode_sentences( 46 | lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label) 47 | sentences = [sent for sent in sentences if len(sent) > 0] 48 | return sentences, vocab 49 | 50 | 51 | def flatten(l): return [item for sublist in l for item in sublist] 52 | 53 | 54 | def principal_components(sentences: list, 55 | vocab_size: int, 56 | n_components: int) -> list: 57 | """PCA on list of integers.""" 58 | # sparse format 59 | row_ind = flatten( 60 | [[i] * len(sentence) for i, sentence in enumerate(sentences)]) 61 | col_ind = flatten(sentences) 62 | shape = (len(sentences), vocab_size) 63 | data = np.ones(len(col_ind)) 64 | X = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape) 65 | X_std = X - X.mean(axis=0) 66 | X_std = X_std / (1e-8 + np.std(X_std, 0)) 67 | Y = sklearn.decomposition.PCA(n_components=n_components).fit_transform(X_std) 68 | return Y 69 | 70 | 71 | def print_nearest_cosine_distance(embeddings: gluon.Parameter, 72 | id2word: dict, 73 | num: int = 10) -> None: 74 | embeddings = embeddings.data().asnumpy().T 75 | top_wordids = list(id2word.keys())[0:num] 76 | distances = sklearn.metrics.pairwise.cosine_similarity( 77 | embeddings[top_wordids], embeddings) 78 | for idx, distance in zip(top_wordids, distances): 79 | top_word_indices = np.argsort(distance)[::-1][1:11] 80 | logger.info('----------') 81 | logger.info("nearest words in cosine distance to: %s" % id2word[idx]) 82 | for nearest in top_word_indices: 83 | logger.info('%.3e\t%s' % (distance[nearest], id2word[nearest])) 84 | 85 | 86 | def _batchify_sentences(data: list, 87 | vocab_size: int) -> typing.Tuple[nd.sparse.CSRNDArray, nd.NDArray]: 88 | """Collate data, a list of sentence, label tuples into a sparse batch.""" 89 | indptr = [0] # row offsets 90 | indices = [] 91 | labels = [] 92 | for row_idx, sentence_and_label in enumerate(data): 93 | sentence, label = sentence_and_label 94 | ptr = indptr[row_idx] + len(sentence) 95 | indptr.append(ptr) 96 | indices.extend(sentence) 97 | labels.append(label) 98 | values = [1] * len(indices) 99 | labels = nd.array(labels) 100 | batch = nd.sparse.csr_matrix(data=values, 101 | indices=indices, 102 | indptr=indptr, 103 | shape=(len(data), vocab_size)) 104 | return batch, labels 105 | 106 | 107 | class SentenceDataLoader(gluon.data.DataLoader): 108 | def __init__(self, 109 | id2word: dict, 110 | principal_components: np.array=None, 111 | data_distribution: str='bernoulli', 112 | **kwargs) -> None: 113 | super(SentenceDataLoader, self).__init__(**kwargs) 114 | self.id2word = id2word 115 | self.principal_components = 
principal_components 116 | self.batch_size = kwargs['batch_size'] 117 | self.data_distribution = data_distribution 118 | if data_distribution == 'indices' and self.batch_size != 1: 119 | raise ValueError( 120 | "Need batch size of 1 for variable-length index representation!") 121 | self.vocab_size = len(id2word) 122 | 123 | def __iter__(self): 124 | for batch in self._batch_sampler: 125 | if self.data_distribution == 'bernoulli': 126 | yield _batchify_sentences( 127 | [self._dataset[idx] for idx in batch], self.vocab_size) 128 | elif self.data_distribution == 'fast_bernoulli': 129 | res = [self._dataset[idx] for idx in batch] 130 | assert len(res) == 1 131 | data, label = res[0] 132 | yield nd.array(data), nd.array([label]) 133 | 134 | 135 | def _batchify_documents(data: list, 136 | vocab_size: int) -> typing.Tuple[nd.sparse.CSRNDArray, nd.NDArray]: 137 | """Collate data, a list of sentence, label tuples into a sparse batch.""" 138 | indptr = [0] # row offsets 139 | indices = [] 140 | labels = [] 141 | values = [] 142 | for row_idx, doc_and_label in enumerate(data): 143 | doc, label = doc_and_label 144 | ptr = indptr[row_idx] + len(doc) 145 | indptr.append(ptr) 146 | word_ids, counts = zip(*doc) 147 | indices.extend(word_ids) 148 | values.extend(counts) 149 | labels.append(label) 150 | labels = nd.array(labels).astype(np.int64) 151 | batch = nd.sparse.csr_matrix(data=values, 152 | indices=indices, 153 | indptr=indptr, 154 | shape=(len(data), vocab_size)) 155 | return batch, labels 156 | 157 | 158 | class DocumentDataLoader(gluon.data.DataLoader): 159 | def __init__(self, 160 | id2word: dict, 161 | **kwargs) -> None: 162 | super(DocumentDataLoader, self).__init__(**kwargs) 163 | self.id2word = id2word 164 | self.vocab_size = len(id2word) 165 | 166 | def __iter__(self): 167 | for batch in self._batch_sampler: 168 | yield _batchify_documents( 169 | [self._dataset[idx] for idx in batch], self.vocab_size) 170 | -------------------------------------------------------------------------------- /deep_exp_fam/deep_exponential_family_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import mxnet as mx 4 | assert mx.__version__ > '0.11.0' # required for autograd.grad() 5 | 6 | from mxnet import nd 7 | from mxnet import gluon 8 | from mxnet import autograd 9 | from typing import Tuple 10 | 11 | from .layers import LatentLayer 12 | from .layers import ObservationLayer 13 | 14 | EPSILON = np.finfo(float).eps 15 | 16 | 17 | class DeepExponentialFamilyModel(object): 18 | def __init__(self, gradient_config: dict, n_data: int) -> None: 19 | self.n_data = n_data 20 | self.gradient_config = gradient_config 21 | self.sequential = gluon.nn.Sequential() 22 | self.size_above = None 23 | self._grads_attached = False 24 | 25 | def name_scope(self): 26 | return self.sequential.name_scope() 27 | 28 | def add(self, block, **kwargs): 29 | if block == LatentLayer: 30 | block = block(n_data=self.n_data, 31 | size_above=self.size_above, 32 | gradient_config=self.gradient_config, 33 | **kwargs) 34 | self.sequential.add(block) 35 | self.size_above = kwargs['size'] 36 | elif block == ObservationLayer: 37 | self.sequential.add( 38 | block(size_above=self.size_above, 39 | gradient_config=self.gradient_config, 40 | **kwargs)) 41 | else: 42 | raise ValueError('Unknown block type: %s' % type(block)) 43 | 44 | def maybe_attach_repeated_params(self): 45 | """Attach repeated params if using score function estimator gradient.""" 46 | cfg = 
self.gradient_config 47 | params = self.collect_params() 48 | point_params = [p for name, p in params.items() if 'point' in name] 49 | self._point_mass_params = point_params 50 | if cfg['estimator'] == 'pathwise': 51 | self._grads_attached = True 52 | return 53 | elif cfg['estimator'] == 'score_function': 54 | for name, param in params.items(): 55 | if 'point_mass' not in name and param.grad_req != 'null': 56 | assert cfg['n_samples'] >= 3, "Require n_samples >=3 for gradient." 57 | if 'emb' in name: 58 | param.n_repeats = cfg['n_samples'] 59 | # this seems to be slower than the dense version! 60 | # param._data[0] = param._data[0].tostype('row_sparse') 61 | # autograd.mark_variables( 62 | # param._data[0], nd.zeros(param.shape).tostype('row_sparse')) 63 | else: 64 | res = nd.repeat(param.data(), repeats=cfg['batch_size'], axis=0) 65 | res = nd.expand_dims(res, 0) 66 | param.repeated = nd.repeat(res, repeats=cfg['n_samples'], axis=0) 67 | # request gradient with respect to each sample and batch datapoint 68 | param.repeated.attach_grad() 69 | param.score_grad = True 70 | score_params = [p for p in params.values() if hasattr(p, 'score_grad')] 71 | self._score_params = score_params 72 | self._params = params 73 | self._grads_attached = True 74 | 75 | def collect_params(self): 76 | return self.sequential.collect_params() 77 | 78 | def save_params(self, *args): 79 | self.sequential.save_params(*args) 80 | 81 | def load_params(self, *args): 82 | self.sequential.load_params(*args) 83 | 84 | def __call__(self, *args, **kwargs): 85 | return self.forward(*args, **kwargs) 86 | 87 | def forward(self, 88 | data_batch: mx.io.DataBatch, 89 | get_sample: bool = True, 90 | ) -> Tuple[nd.NDArray, nd.NDArray, nd.NDArray]: 91 | """Return ELBO, reconstruction or z_sample from the DEF.""" 92 | if not self._grads_attached: 93 | raise ValueError('Must call maybe_attach_repeated_params() first!') 94 | log_q_sum, elbo, sample = None, None, None 95 | for i, block in enumerate(self.sequential._children): 96 | log_q_sum, elbo, sample = block(data_batch, log_q_sum, elbo, sample) 97 | return log_q_sum, elbo, sample 98 | 99 | def compute_gradients(self, 100 | elbo: nd.NDArray, 101 | data_batch: mx.io.DataBatch = None, 102 | log_q_sum: nd.NDArray = None, 103 | mode: str = 'train') -> None: 104 | """Compute gradients and assign them to variational parameters. 105 | 106 | Args: 107 | elbo: evidence lower bound that we maximize 108 | data_batch: minibatch of data with data indices as labels 109 | log_q_sum: sum of log probs of samples from variational distributions q. 
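      mode: point-mass parameter gradients are only assigned when this is 'train'.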
110 | """ 111 | cfg = self.gradient_config 112 | if cfg['estimator'] == 'pathwise': 113 | for block in self.sequential._children: 114 | for child_block in block._children: 115 | if hasattr(child_block, 'is_reparam'): 116 | assert child_block.is_reparam == True 117 | if len(self._point_mass_params) > 0 and mode == 'train': 118 | variables = [p.data() for p in self._point_mass_params] 119 | assert elbo.shape[-1] == cfg['batch_size'] 120 | loss = nd.mean(-elbo, -1) 121 | point_mass_grads = autograd.grad(loss, variables, retain_graph=True) 122 | _assign_grads(self._point_mass_params, point_mass_grads) 123 | if cfg['estimator'] == 'pathwise': 124 | (-elbo).backward() 125 | elif cfg['estimator'] == 'score_function': 126 | variables = [param.repeated for param in self._score_params] 127 | score_functions = autograd.grad(log_q_sum, variables) 128 | mx.autograd.set_recording(False) 129 | score_grads = [] 130 | for param, score_function in zip(self._score_params, score_functions): 131 | grad = _leave_one_out_gradient_estimator(score_function, -elbo) 132 | if 'emb' in param.name: 133 | # turns out the sparse implementation is not faster?! 134 | # data, label = data_batch 135 | # label = label.astype(np.int64) 136 | # grad = nd.sparse.row_sparse_array( 137 | # grad, indices=label, shape=param.shape) 138 | # need to broadcast for embeddings 139 | one_hot = nd.one_hot(data_batch[1], depth=self.n_data) 140 | grad = nd.dot(one_hot, grad, transpose_a=True) 141 | score_grads.append(grad) 142 | _assign_grads(self._score_params, score_grads) 143 | 144 | 145 | def _leave_one_out_gradient_estimator(h, f, zero_mean_h=False): 146 | """Estimate gradient of f using score function and control variate h. 147 | 148 | Optimal scaling of control variate is given by: a = Cov(h, f) / Var(h). 149 | """ 150 | if h.ndim > f.ndim: 151 | # expand parameter dimension (last dimension summed over in f) 152 | f = nd.expand_dims(f, f.ndim) 153 | grad_f = h * f 154 | if zero_mean_h: 155 | cov_h_f = _leave_one_out_mean(h * grad_f) 156 | var_h = _leave_one_out_mean(h * h) 157 | else: 158 | cov_h_f = _held_out_covariance(h, grad_f) 159 | var_h = _held_out_covariance(h, h) 160 | # sampling zero for low-variance score functions is probable, so add EPSILON! 161 | optimal_a = cov_h_f / (EPSILON + var_h) 162 | if h.ndim == 2: 163 | # If no batch dim: nd.Embedding removes batch dim for batches of size 1 164 | keepdims = True 165 | else: 166 | keepdims = False 167 | return nd.mean(grad_f - optimal_a * h, 0, keepdims=keepdims) 168 | 169 | 170 | def _leave_one_out_mean(a: nd.NDArray) -> nd.NDArray: 171 | """Compute leave-one-out mean of array of shape [n_samples, ...].""" 172 | n_samples = a.shape[0] 173 | assert n_samples >= 256, "Need at least 256 samples for accuracy." 174 | res = (nd.sum(a, 0, keepdims=True) - a) / (n_samples - 2) 175 | assert res.shape == a.shape 176 | return res 177 | 178 | 179 | def _held_out_covariance(x, y): 180 | """Get held-out covariance between x and y in the first dimension.""" 181 | n = x.shape[0] 182 | assert y.shape[0] == n 183 | mean_x = nd.mean(x, 0) 184 | mean_y = nd.mean(y, 0) 185 | res = nd.sum((x - mean_x) * (y - mean_y), 0) 186 | mean_x_held_out = (mean_x - x / n) / (1. - 1. / n) # * n / (n - 1.) 187 | mean_y_held_out = (mean_y - y / n) / (1. - 1. / n) # * n / (n - 1.) 188 | res = res - (x - mean_x_held_out) * (y - mean_y_held_out) 189 | return res / (n - 2.) 
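
# A minimal sanity-check sketch for the estimator above (illustrative addition,
# not part of the original module; the helper name is made up). Here h stands
# in for a score function: for q = N(mu, 1) at mu = 0, d/dmu log q(x) = x, so
# draws x ~ N(0, 1) are exact score samples. With a near-constant objective f
# the true gradient is zero; the naive estimate nd.mean(h * f, 0) is noisy,
# while the control-variate estimate should sit much closer to zero.
def _control_variate_sanity_check(n_samples: int = 512) -> None:
  h = nd.random.normal(shape=(n_samples, 1))               # score samples
  f = 10. + 0.1 * nd.random.normal(shape=(n_samples, 1))   # near-constant objective
  naive = nd.mean(h * f, 0)
  controlled = _leave_one_out_gradient_estimator(h, f)
  print('naive:', naive.asnumpy(), 'controlled:', controlled.asnumpy())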
190 | 
191 | 
192 | def _variance(a: nd.NDArray) -> nd.NDArray:
193 |   """Compute variance of a of shape [n_samples, ...]."""
194 |   mean = nd.mean(a, 0, keepdims=True)
195 |   return nd.mean(nd.square(a - mean), 0)
196 | 
197 | 
198 | def _assign_grads(params: list, grads: list) -> None:
199 |   """Assign each gradient to its parameter's gradient buffer (first context)."""
200 |   for param, grad in zip(params, grads):
201 |     assert param._grad[0].shape == grad.shape
202 |     param._grad[0] = grad
203 |     param._data[0]._fresh_grad = 1
204 | 
--------------------------------------------------------------------------------
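
# Illustrative usage sketch (not a file in the repository): how the classes
# above appear to compose, inferred only from the code shown. The real driver
# lives in common/fit.py and tests/tester.py (not shown here), so details such
# as initialization and the optimizer may differ; the pathwise estimator is
# used because it needs the least machinery, and all config values are
# illustrative.
import mxnet as mx
from mxnet import nd, autograd, gluon

from deep_exp_fam import DeepExponentialFamilyModel, layers

gradient_config = {'estimator': 'pathwise', 'n_samples': 1, 'batch_size': 1}
model = DeepExponentialFamilyModel(gradient_config=gradient_config, n_data=1)
with model.name_scope():
  model.add(layers.LatentLayer, latent_distribution='gaussian', size=3,
            p_z_variance=1.)
  model.add(layers.ObservationLayer, weight_distribution='point_mass',
            data_size=1)
model.collect_params().initialize(mx.init.Uniform(0.07))
model.maybe_attach_repeated_params()
trainer = gluon.Trainer(model.collect_params(), 'rmsprop',
                        {'learning_rate': 0.1})

# One gradient step on a single datapoint; labels are datapoint indices into
# the per-datapoint (lookup) variational parameters.
data_batch = (nd.array([[3.3]]), nd.array([0]))
with autograd.record():
  log_q_sum, elbo, sample = model(data_batch)
model.compute_gradients(elbo, data_batch=data_batch, log_q_sum=log_q_sum)
trainer.step(gradient_config['batch_size'])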