├── .gitignore
├── README.md
├── notebooks
│   └── KMNClass.ipynb
├── requirements.txt
└── src
    ├── __init__.py
    ├── kmn.py
    └── test
        └── test_kmn.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# DotEnv configuration
.env

# Database
*.db
*.rdb

# Pycharm
.idea

# IPython NB Checkpoints
.ipynb_checkpoints/

# exclude data from source control by default
/data/

# History files
.Rhistory
.Rapp.history

# Session Data files
.RData

# Example code in package build process
*-Ex.R

# Output files from R CMD build
/*.tar.gz

# Output files from R CMD check
/*.Rcheck/

# RStudio files
.Rproj.user/

# produced vignettes
vignettes/*.html
vignettes/*.pdf

# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
.httr-oauth

# knitr and R markdown default cache directories
/*_cache/
/cache/

# Temporary files created by R markdown
*.utf8.md
*.knit.md

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# KMN

A Kernel Mixture Network implementation based on [Ambrogioni et al. 2017](https://arxiv.org/abs/1705.07111) with some minor tweaks
(kernel center clustering and trainable scales).
We provide notebooks with both a low-level implementation in [TensorFlow](https://www.tensorflow.org/)
and a plug-and-play estimator class built on [Keras](https://keras.io/) and [Edward](http://edwardlib.org).
For more technical details, see Jan van der Vegt's blog post on [Kernel Mixture Networks](https://janvdvegt.github.io/2017/06/07/Kernel-Mixture-Networks.html),
or ["How to obtain advanced probabilistic predictions for your data science use case"](http://www.bigdatarepublic.nl/kernel-mixture-networks/)
for a high-level summary.

# KernelMixtureNetwork Class

This class lets you plug in your own network together with a placeholder for its input, and uses kernels placed on your training data to model probability densities conditioned on that input. Something the class adds that is not discussed in the paper is the option to train the bandwidths of the kernels. Currently only Gaussian kernels are supported, but the class is easily extended. It is not meant as a package, but as a reference for how to apply this technique with TensorFlow, Edward and Keras.
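
A minimal usage sketch of the `KernelMixtureNetwork` class from `src/kmn.py` (the data and hyperparameters below are made up for illustration):

```python
import numpy as np
from src.kmn import KernelMixtureNetwork

# toy conditional data: X must be 2-D (n_samples, n_features), y 1-D
X = np.random.uniform(-1, 1, size=(1000, 1))
y = np.sin(3 * X[:, 0]) + 0.1 * np.random.randn(1000)

kmn = KernelMixtureNetwork(center_sampling_method='k_means', n_centers=20,
                           keep_edges=True, train_scales=True)
kmn.fit(X, y, n_epoch=300)

print(kmn.score(X, y))   # mean log-likelihood of (X, y)
samples = kmn.sample(X)  # one sample per conditional distribution
# one row of densities per input row, over an automatically chosen y grid
densities = kmn.predict_density(X[:5], resolution=100)
kmn.plot_loss()
```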
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
scikit-learn
tensorflow
keras
edward
pandas
matplotlib
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/janvdvegt/KernelMixtureNetwork/b5fe2d81a6cd90bf3aeb1f2aaf42f77b16b4f024/src/__init__.py
--------------------------------------------------------------------------------
/src/kmn.py:
--------------------------------------------------------------------------------
import math

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from edward.models import Categorical, Mixture, Normal
from keras.layers import Dense, Dropout
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans, AgglomerativeClustering


def sample_center_points(y, method='all', k=100, keep_edges=False):
    """
    Define kernel centers, with several downsampling alternatives.
    """
    # make sure y is 1-D
    y = y.ravel()

    # keep all points as kernel centers
    if method == 'all':
        return y

    # retain the outer points to ensure expressiveness at the target borders
    if keep_edges:
        y = np.sort(y)
        centers = np.array([y[0], y[-1]])
        y = y[1:-1]
        # adjust k such that the final output has size k
        k -= 2
    else:
        centers = np.empty(0)

    if method == 'random':
        cluster_centers = np.random.choice(y, k, replace=False)

    # iteratively drop one point of the closest pair until all remaining
    # points are at least some distance 'd' apart
    elif method == 'distance':
        raise NotImplementedError

    # use 1-D k-means clustering
    elif method == 'k_means':
        model = KMeans(n_clusters=k, n_jobs=-2)
        model.fit(y.reshape(-1, 1))
        cluster_centers = model.cluster_centers_

    # use agglomerative clustering
    elif method == 'agglomerative':
        model = AgglomerativeClustering(n_clusters=k, linkage='complete')
        model.fit(y.reshape(-1, 1))
        labels = pd.Series(model.labels_, name='label')
        y_s = pd.Series(y, name='y')
        df = pd.concat([y_s, labels], axis=1)
        cluster_centers = df.groupby('label')['y'].mean().values

    else:
        raise ValueError("unknown method '{}'".format(method))

    return np.append(centers, cluster_centers)
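
# Example usage of sample_center_points (illustrative): with keep_edges=True
# the two extreme values of y are kept as fixed centers and the remaining
# k - 2 centers come from the chosen downsampling method, so the result has
# exactly k entries:
#
#     y = np.random.randn(500)
#     centers = sample_center_points(y, method='k_means', k=10, keep_edges=True)
#     assert len(centers) == 10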


class KernelMixtureNetwork(BaseEstimator):

    def __init__(self, n_samples=10, center_sampling_method='k_means', n_centers=20, keep_edges=False,
                 init_scales='default', estimator=None, X_ph=None, train_scales=False):
        """
        Main class for the Kernel Mixture Network.

        Args:
            n_samples: Number of samples to return
            center_sampling_method: Method used to find the kernel centers
            n_centers: Number of kernel centers to use in the output
            keep_edges: Keep the extreme y values as centers to preserve expressiveness at the borders
            init_scales: List or scalar with the initial value(s) of the bandwidth parameter
            estimator: Keras or TensorFlow network that ends with a dense layer, on top of which the
                kernel mixture output is placed; if None, a standard 15 -> 15 dense network is used
            X_ph: Placeholder for the input to your custom estimator; currently only a single input
                placeholder is supported, but this should be easy to extend to a list of placeholders
            train_scales: Whether or not to make the kernel scales trainable
        """
        self.sess = ed.get_session()
        self.inference = None

        self.estimator = estimator
        self.X_ph = X_ph

        self.n_samples = n_samples
        self.center_sampling_method = center_sampling_method
        self.n_centers = n_centers
        self.keep_edges = keep_edges

        self.train_loss = np.empty(0)
        self.test_loss = np.empty(0)

        if init_scales == 'default':
            init_scales = np.array([1])

        # transform with the inverse softplus, log(exp(s) - 1), so that the
        # softplus applied in _build_model() recovers the passed init_scales
        self.init_scales = [math.log(math.exp(s) - 1) for s in init_scales]
        self.n_scales = len(self.init_scales)
        self.train_scales = train_scales

        self.fitted = False

    def fit(self, X, y, n_epoch, **kwargs):
        """
        Build and train the model.
        """
        # define the full model
        self._build_model(X, y)

        # set up the inference procedure
        self.inference = ed.MAP(data={self.mixtures: self.y_ph})
        self.inference.initialize(var_list=tf.trainable_variables(), n_iter=n_epoch)
        tf.global_variables_initializer().run()

        # train the model
        self.partial_fit(X, y, n_epoch=n_epoch, **kwargs)
        self.fitted = True

    def partial_fit(self, X, y, n_epoch=1, eval_set=None):
        """
        Update the model with additional training epochs.
        """
        print("fitting model")

        # loop over epochs
        for i in range(n_epoch):

            # run inference, update the trainable variables of the model
            info_dict = self.inference.update(feed_dict={self.X_ph: X, self.y_ph: y})

            train_loss = info_dict['loss'] / len(y)
            self.train_loss = np.append(self.train_loss, -train_loss)

            if eval_set is not None:
                X_test, y_test = eval_set
                test_loss = self.sess.run(self.inference.loss, feed_dict={self.X_ph: X_test, self.y_ph: y_test}) / len(y_test)
                self.test_loss = np.append(self.test_loss, -test_loss)

            # only print progress for the initial fit, not for additional updates
            if not self.fitted:
                self.inference.print_progress(info_dict)

        print("mean log-loss train: {:.3f}".format(train_loss))
        if eval_set is not None:
            print("mean log-loss test: {:.3f}".format(test_loss))

        print("optimal scales: {}".format(self.sess.run(self.scales)))

    def predict(self, X, y):
        """
        Likelihood of the given target values under the fitted conditional densities.
        """
        return self.sess.run(self.likelihoods, feed_dict={self.X_ph: X, self.y_ph: y})

    def predict_density(self, X, y=None, resolution=100):
        """
        Conditional density over a predefined grid of target values.
        """
        if y is None:
            max_scale = np.max(self.sess.run(self.scales))
            y = np.linspace(self.y_min - 2.5 * max_scale, self.y_max + 2.5 * max_scale, num=resolution)

        return self.sess.run(self.densities, feed_dict={self.X_ph: X, self.y_grid_ph: y})

    def sample(self, X):
        """
        Sample from the conditional mixture distributions.
        """
        return self.sess.run(self.samples, feed_dict={self.X_ph: X})

    def score(self, X, y):
        """
        Return the mean log-likelihood of the data.
        """
        likelihoods = self.predict(X, y)
        return np.log(likelihoods).mean()
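
    # For reference, the graph constructed in _build_model() below represents
    # the conditional density
    #
    #     p(y | x) = sum_k w_k(x) * N(y; loc_k, scale_k)
    #
    # with one Gaussian component per (center, scale) pair, i.e.
    # n_centers * n_scales components in total. The softplus output layer
    # produces one positive value per component, and Edward's Categorical
    # treats these values as logits, so the effective mixture weights w_k(x)
    # are their softmax.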
    def _build_model(self, X, y):
        """
        Set up the computation graph of the KMN.
        """
        # create a placeholder for the target
        self.y_ph = y_ph = tf.placeholder(tf.float32, [None])
        self.n_sample_ph = tf.placeholder(tf.int32, None)

        # store the feature dimension size for the placeholder
        self.n_features = X.shape[1]

        # if no external estimator is provided, create a default neural network
        if self.estimator is None:
            self.X_ph = tf.placeholder(tf.float32, [None, self.n_features])
            # two dense hidden layers with 15 nodes each
            x = Dense(15, activation='relu')(self.X_ph)
            x = Dense(15, activation='relu')(x)
            self.estimator = x

        # get the batch size
        self.batch_size = tf.shape(self.X_ph)[0]

        # locations of the gaussian kernel centers
        n_locs = self.n_centers
        self.locs = locs = sample_center_points(y, method=self.center_sampling_method, k=n_locs, keep_edges=self.keep_edges)
        self.locs_array = locs_array = tf.unstack(tf.transpose(tf.multiply(tf.ones((self.batch_size, n_locs)), locs)))

        # scales of the gaussian kernels
        self.scales = scales = tf.nn.softplus(tf.Variable(self.init_scales, dtype=tf.float32, trainable=self.train_scales))
        self.scales_array = scales_array = tf.unstack(tf.transpose(tf.multiply(tf.ones((self.batch_size, self.n_scales)), scales)))

        # kernel weights, as output by the neural network
        self.weights = weights = Dense(n_locs * self.n_scales, activation='softplus')(self.estimator)

        # mixture distributions
        self.cat = cat = Categorical(logits=weights)
        self.components = components = [Normal(loc=loc, scale=scale) for loc in locs_array for scale in scales_array]
        self.mixtures = mixtures = Mixture(cat=cat, components=components, value=tf.zeros_like(y_ph))

        # tensor to store samples
        # self.samples = mixtures.sample(sample_shape=self.n_samples)
        self.samples = mixtures.sample()

        # store the min/max of the training target values for a sensible default grid in self.predict_density()
        # self.y_range = y.max() - y.min()
        # self.y_min = y.min() - 0.1 * self.y_range
        # self.y_max = y.max() + 0.1 * self.y_range
        self.y_min = y.min()
        self.y_max = y.max()

        # placeholder for the grid
        self.y_grid_ph = y_grid_ph = tf.placeholder(tf.float32)
        # tensor to store the grid point densities
        self.densities = tf.transpose(mixtures.prob(tf.reshape(y_grid_ph, (-1, 1))))

        # tensor to compute likelihoods
        self.likelihoods = mixtures.prob(y_ph)

    def plot_loss(self):
        """
        Plot the train loss and optionally the test loss over epochs.
        source: http://edwardlib.org/tutorials/mixture-density-network
        """
        # new figure
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 3))

        # plot train loss
        plt.plot(np.arange(len(self.train_loss)), self.train_loss, label='Train')

        if len(self.test_loss) > 0:
            # plot test loss
            plt.plot(np.arange(len(self.test_loss)), self.test_loss, label='Test')

        plt.legend(fontsize=20)
        plt.xlabel('epoch', fontsize=15)
        # the stored values are the negated per-point losses, i.e. mean log-likelihoods
        plt.ylabel('mean log-likelihood', fontsize=15)
        plt.show()

        return fig, axes
--------------------------------------------------------------------------------
/src/test/test_kmn.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
from src.kmn import KernelMixtureNetwork
from keras.layers import Dense


class TestKernelMixtureNetwork(tf.test.TestCase):

    def create_dataset(self, n=5000):
        """
        Create dummy data.
        source: http://edwardlib.org/tutorials/mixture-density-network
        """
        y_data = np.random.uniform(-10.5, 10.5, n)
        r_data = np.random.normal(size=n)  # random noise
        x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
        x_data = x_data.reshape((n, 1))

        return train_test_split(x_data, y_data, random_state=42)

    def test_run(self):
        """Test case with simulated data and network training with default settings."""

        X_train, X_test, y_train, y_test = self.create_dataset()

        kmn = KernelMixtureNetwork()

        self.assertTrue(isinstance(kmn, object))

        kmn.fit(X_train, y_train, n_epoch=100, eval_set=(X_test, y_test))

        # TODO: make this test deterministic!
        train_loss1 = kmn.train_loss[-1]
        self.assertTrue(train_loss1 < 2.)
        self.assertTrue(kmn.test_loss[-1] < 3.)

        kmn.partial_fit(X_train, y_train, n_epoch=200, eval_set=(X_test, y_test))
        self.assertTrue(kmn.train_loss[-1] <= train_loss1)

        likelihoods = kmn.predict(X_test, y_test)
        mean_loglik = np.log(likelihoods).mean()

        self.assertTrue(mean_loglik < 3.)

        score = kmn.score(X_test, y_test)
        self.assertTrue(abs(mean_loglik - score) < 0.01)

        kmn.sess.close()

        # TODO:
        # test for sample()
        # test for predict_density()
        # test for plot_loss()

    def test_external_estimator(self):
        """Test case with simulated data and network training with an external estimator."""

        X_train, X_test, y_train, y_test = self.create_dataset()

        kmn1 = KernelMixtureNetwork()
        kmn1.fit(X_train, y_train, n_epoch=100)
        kmn1.sess.close()

        X_ph = tf.placeholder(tf.float32, [None, X_train.shape[1]])
        x = Dense(15, activation='relu')(X_ph)
        neural_network = Dense(15, activation='relu')(x)

        kmn2 = KernelMixtureNetwork(estimator=neural_network, X_ph=X_ph)
        kmn2.fit(X_train, y_train, n_epoch=200)
        kmn2.sess.close()

        self.assertTrue(abs(kmn1.train_loss[-1] - kmn2.train_loss[-1]) < 0.1)

    def test_sample_center_points(self):
        pass

        # TODO:
        # test sample_center_points() with all different methods
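
    # A minimal sketch of such a test (illustrative; the exact-k behaviour
    # and range guarantees are inferred from sample_center_points in src/kmn.py):
    def test_sample_center_points_methods(self):
        from src.kmn import sample_center_points

        y = np.random.uniform(-1, 1, size=500)

        # 'all' keeps every training point as a kernel center
        self.assertEqual(len(sample_center_points(y, method='all')), len(y))

        # the downsampling methods return exactly k centers within the range of y
        for method in ['random', 'k_means', 'agglomerative']:
            centers = sample_center_points(y, method=method, k=10, keep_edges=True)
            self.assertEqual(len(centers), 10)
            self.assertTrue(centers.min() >= y.min())
            self.assertTrue(centers.max() <= y.max())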

if __name__ == '__main__':
    tf.test.main()

--------------------------------------------------------------------------------