├── .gitignore ├── tf_iris.png ├── iris_sklearn.png ├── requirements.txt ├── README.md ├── sklearn_iris.py └── tf_iris.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /tf_iris.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolov/naive_bayes_tensorflow/HEAD/tf_iris.png -------------------------------------------------------------------------------- /iris_sklearn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolov/naive_bayes_tensorflow/HEAD/iris_sklearn.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bleach==1.5.0 2 | cycler==0.10.0 3 | decorator==4.1.2 4 | enum34==1.1.6 5 | html5lib==0.9999999 6 | ipython==6.2.1 7 | ipython-genutils==0.2.0 8 | jedi==0.11.0 9 | Markdown==2.6.9 10 | matplotlib==2.1.0 11 | numpy==1.13.3 12 | parso==0.1.0 13 | pexpect==4.2.1 14 | pickleshare==0.7.4 15 | prompt-toolkit==1.0.15 16 | protobuf==3.4.0 17 | ptyprocess==0.5.2 18 | Pygments==2.2.0 19 | pyparsing==2.2.0 20 | PyQt5==5.9 21 | python-dateutil==2.6.1 22 | pytz==2017.2 23 | scikit-learn==0.19.1 24 | scipy==1.0.0 25 | simplegeneric==0.8.1 26 | sip==4.19.3 27 | six==1.11.0 28 | sklearn==0.0 29 | tensorflow==1.4.0 30 | tensorflow-tensorboard==0.4.0rc2 31 | traitlets==4.3.2 32 | wcwidth==0.1.7 33 | Werkzeug==0.12.2 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Naive Bayes classifiers in TensorFlow 2 | 3 | 4 | 5 | A simple [Naive Bayes classifier](https://en.wikipedia.org/wiki/Naive_Bayes_classifier) in TensorFlow 1.4.
It's a tidy 6 | demonstration of [`tf.distributions`](https://www.tensorflow.org/api_docs/python/tf/distributions) and some unusual tensor operations. 7 | 8 | For more information, you can read the [blog post](http://nicolovaligi.com/naive-bayes-tensorflow.html). 9 | 10 | ## Getting started 11 | 12 | Prepare the Python environment: 13 | 14 | ``` 15 | # Create a new virtualenv 16 | mkvirtualenv env 17 | source env/bin/activate 18 | # Install requirements 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | And run the classifier: 23 | 24 | ``` 25 | python tf_iris.py 26 | ``` 27 | -------------------------------------------------------------------------------- /sklearn_iris.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Flower classification on the Iris dataset using a Naive Bayes 5 | classifier and TensorFlow. 6 | 7 | For more info: http://nicolovaligi.com/naive-bayes-tensorflow.html 8 | """ 9 | 10 | from IPython import embed 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | from matplotlib import colors 14 | 15 | from sklearn.naive_bayes import GaussianNB 16 | from sklearn import datasets 17 | 18 | if __name__ == '__main__': 19 | iris = datasets.load_iris() 20 | # Only take the first two features 21 | X = iris.data[:, :2] 22 | y = iris.target 23 | 24 | x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 25 | y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 26 | 27 | # Fit the Naive Bayes classifier 28 | gnb = GaussianNB() 29 | gnb.fit(X, y) 30 | 31 | # Classify a grid of points 32 | xx, yy = np.meshgrid(np.linspace(x_min, x_max, 30), 33 | np.linspace(y_min, y_max, 30)) 34 | Z = gnb.predict_proba(np.c_[xx.ravel(), yy.ravel()]) 35 | Z1 = Z[:, 1].reshape(xx.shape) 36 | Z2 = Z[:, 2].reshape(xx.shape) 37 | 38 | # Plot 39 | fig = plt.figure(figsize=(5, 3.75)) 40 | ax = fig.add_subplot(111) 41 | ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, 42 | 
edgecolor='k') 43 | # Swap signs to make the contour dashed (MPL default) 44 | ax.contour(xx, yy, -Z1, [-0.5], colors='k') 45 | ax.contour(xx, yy, -Z2, [-0.5], colors='k') 46 | 47 | # Plot formatting 48 | ax.set_xlabel('Sepal length') 49 | ax.set_ylabel('Sepal width') 50 | ax.set_title('sklearn decision boundary') 51 | ax.set_xlim(x_min, x_max) 52 | ax.set_ylim(y_min, y_max) 53 | ax.set_xticks(()) 54 | ax.set_yticks(()) 55 | 56 | plt.tight_layout() 57 | fig.savefig('iris_sklearn.png', bbox_inches='tight') 58 | -------------------------------------------------------------------------------- /tf_iris.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Flower classification on the Iris dataset using a Naive Bayes 5 | classifier and TensorFlow. 6 | 7 | For more info: http://nicolovaligi.com/naive-bayes-tensorflow.html 8 | """ 9 | 10 | from IPython import embed 11 | from matplotlib import colors 12 | from matplotlib import pyplot as plt 13 | from sklearn import datasets 14 | import numpy as np 15 | import tensorflow as tf 16 | from sklearn.utils.fixes import logsumexp 17 | import numpy as np 18 | 19 | 20 | class TFNaiveBayesClassifier: 21 | dist = None 22 | 23 | def fit(self, X, y): 24 | # Separate training points by class (nb_classes * nb_samples * nb_features) 25 | unique_y = np.unique(y) 26 | points_by_class = np.array([ 27 | [x for x, t in zip(X, y) if t == c] 28 | for c in unique_y]) 29 | 30 | # Estimate mean and variance for each class / feature 31 | # shape: nb_classes * nb_features 32 | mean, var = tf.nn.moments(tf.constant(points_by_class), axes=[1]) 33 | 34 | # Create a 3x2 univariate normal distribution with the 35 | # known mean and variance 36 | self.dist = tf.distributions.Normal(loc=mean, scale=tf.sqrt(var)) 37 | 38 | def predict(self, X): 39 | assert self.dist is not None 40 | nb_classes, nb_features = map(int, self.dist.scale.shape) 41 | 42 | # Conditional probabilities log 
P(x|c) with shape 43 | # (nb_samples, nb_classes) 44 | cond_probs = tf.reduce_sum( 45 | self.dist.log_prob( 46 | tf.reshape( 47 | tf.tile(X, [1, nb_classes]), [-1, nb_classes, nb_features])), 48 | axis=2) 49 | 50 | # uniform priors 51 | priors = np.log(np.array([1. / nb_classes] * nb_classes)) 52 | 53 | # posterior log probability, log P(c) + log P(x|c) 54 | joint_likelihood = tf.add(priors, cond_probs) 55 | 56 | # normalize to get (log)-probabilities 57 | norm_factor = tf.reduce_logsumexp( 58 | joint_likelihood, axis=1, keep_dims=True) 59 | log_prob = joint_likelihood - norm_factor 60 | # exp to get the actual probabilities 61 | return tf.exp(log_prob) 62 | 63 | 64 | if __name__ == '__main__': 65 | iris = datasets.load_iris() 66 | # Only take the first two features 67 | X = iris.data[:, :2] 68 | y = iris.target 69 | 70 | tf_nb = TFNaiveBayesClassifier() 71 | tf_nb.fit(X, y) 72 | 73 | # Create a regular grid and classify each point 74 | x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 75 | y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 76 | xx, yy = np.meshgrid(np.linspace(x_min, x_max, 30), 77 | np.linspace(y_min, y_max, 30)) 78 | s = tf.Session() 79 | Z = s.run(tf_nb.predict(np.c_[xx.ravel(), yy.ravel()])) 80 | # Extract probabilities of class 2 and 3 81 | Z1 = Z[:, 1].reshape(xx.shape) 82 | Z2 = Z[:, 2].reshape(xx.shape) 83 | 84 | # Plot 85 | fig = plt.figure(figsize=(5, 3.75)) 86 | ax = fig.add_subplot(111) 87 | 88 | ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, 89 | edgecolor='k') 90 | # Swap signs to make the contour dashed (MPL default) 91 | ax.contour(xx, yy, -Z1, [-0.5], colors='k') 92 | ax.contour(xx, yy, -Z2, [-0.5], colors='k') 93 | 94 | ax.set_xlabel('Sepal length') 95 | ax.set_ylabel('Sepal width') 96 | ax.set_title('TensorFlow decision boundary') 97 | ax.set_xlim(x_min, x_max) 98 | ax.set_ylim(y_min, y_max) 99 | ax.set_xticks(()) 100 | ax.set_yticks(()) 101 | 102 | plt.tight_layout() 103 | fig.savefig('tf_iris.png', 
bbox_inches='tight') 104 | --------------------------------------------------------------------------------