├── README.md ├── LICENSE ├── estimate_gmm_sklearn.py └── visualization.py /README.md: -------------------------------------------------------------------------------- 1 | ## Tutorial on GMMs 2 | 3 | This code was used in the blog post ["What is a Gaussian Mixture Model (GMM) - 3D Point Cloud Classification Primer"](https://wp.me/p8QHD6-9Z). 4 | 5 | It is composed of three main parts: 6 | 7 | * Generating data 8 | * Fitting the Gaussian Mixture Model 9 | * Visualization 10 | 11 | ### Installation 12 | You will need to have `matplotlib`, `scikit-learn` and of course `numpy` installed. 13 | 14 | The code was tested on Python 3.5.2 on Windows. 15 | ### Usage 16 | Simply run `estimate_gmm_sklearn.py`. 17 | Change the variable `D` to be 2 or 3 for 2D or 3D results respectively. 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Yizhak Ben-Shabat 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# --------------------------------------------------------------------------------
# /estimate_gmm_sklearn.py:
# --------------------------------------------------------------------------------
"""Generate synthetic Gaussian clusters, fit a diagonal GMM and plot the result.

Set ``D`` to 2 or 3 to get the 2D or the 3D demonstration respectively.
"""
import numpy as np
import visualization
from sklearn.mixture import GaussianMixture

# Generate synthetic data
N, D = 1000, 2  # number of points drawn per Gaussian, and dimensionality

if D == 2:
    # Gaussian centers and (diagonal) covariances in 2D
    means = np.array([[0.5, 0.0],
                      [0, 0],
                      [-0.5, -0.5],
                      [-0.8, 0.3]])
    covs = np.array([np.diag([0.01, 0.01]),
                     np.diag([0.025, 0.01]),
                     np.diag([0.01, 0.025]),
                     np.diag([0.01, 0.01])])
elif D == 3:
    # Gaussian centers and (diagonal) covariances in 3D
    means = np.array([[0.5, 0.0, 0.0],
                      [0.0, 0.0, 0.0],
                      [-0.5, -0.5, -0.5],
                      [-0.8, 0.3, 0.4]])
    covs = np.array([np.diag([0.01, 0.01, 0.03]),
                     np.diag([0.08, 0.01, 0.01]),
                     np.diag([0.01, 0.05, 0.01]),
                     np.diag([0.03, 0.07, 0.01])])
else:
    # Fail early with a clear message instead of a NameError on `means` below.
    raise ValueError("D must be 2 or 3, got {!r}".format(D))
n_gaussians = means.shape[0]

# Draw N samples from every Gaussian. The samples stay grouped by generating
# Gaussian, which is the ordering the visualization functions rely on when
# they color the points.
points = np.concatenate(
    [np.random.multivariate_normal(mean, cov, N) for mean, cov in zip(means, covs)]
)

# Fit the Gaussian mixture model
gmm = GaussianMixture(n_components=n_gaussians, covariance_type='diag')
gmm.fit(points)

# Visualize: the plotting functions take means and standard deviations with
# one column per Gaussian, hence the transposes; the standard deviations are
# the square roots of the fitted diagonal covariances.
if D == 2:
    visualization.visualize_2D_gmm(points, gmm.weights_, gmm.means_.T, np.sqrt(gmm.covariances_).T)
elif D == 3:
    visualization.visualize_3d_gmm(points, gmm.weights_, gmm.means_.T, np.sqrt(gmm.covariances_).T)
# --------------------------------------------------------------------------------
# /visualization.py:
# --------------------------------------------------------------------------------
"""Matplotlib helpers for plotting points together with a diagonal-covariance GMM."""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- kept: presumably imported to register the '3d' projection
import matplotlib.cm as cmx
import os


def _component_index_groups(n_points, n_gaussians):
    """Split the indices 0..n_points-1 into n_gaussians contiguous groups."""
    if n_gaussians <= 0:
        return []
    # array_split tolerates a point count that is not a multiple of
    # n_gaussians: no index is dropped and none runs past the end.
    return np.array_split(np.arange(n_points), n_gaussians)


def _export_current_figure(filename):
    """Save the current figure as a PNG in the 'images/' directory, creating it if needed."""
    os.makedirs('images', exist_ok=True)
    plt.savefig('images/' + filename, dpi=100, format='png')


def visualize_3d_gmm(points, w, mu, stdev, export=True):
    """
    Plot points and their corresponding GMM model in 3D.

    The points are scattered in one color per Gaussian and every Gaussian is
    drawn as a translucent ellipsoid (see plot_sphere).
    Input:
        points: n_points X 3, sampled points. They are split by position into
                n_gaussians contiguous groups, and group i gets color i, so
                they are expected to be ordered by Gaussian.
        w: n_gaussians, gmm weights
        mu: 3 X n_gaussians, gmm means
        stdev: 3 X n_gaussians, gmm standard deviation (assuming diagonal covariance matrix)
        export: when True, the figure is also saved to
                'images/3D_GMM_demonstration.png'
    Output:
        None
    """
    n_gaussians = mu.shape[1]

    # Visualize data
    fig = plt.figure(figsize=(8, 8))
    axes = fig.add_subplot(111, projection='3d')
    axes.set_xlim([-1, 1])
    axes.set_ylim([-1, 1])
    axes.set_zlim([-1, 1])
    plt.set_cmap('Set1')
    colors = cmx.Set1(np.linspace(0, 1, n_gaussians))
    groups = _component_index_groups(points.shape[0], n_gaussians)
    for i, idx in enumerate(groups):
        # `color=` (not `c=`): a single RGBA array passed as `c` is ambiguous
        # and is read as per-point values when the group holds 4 points.
        axes.scatter(points[idx, 0], points[idx, 1], points[idx, 2], alpha=0.3, color=colors[i])
        plot_sphere(w=w[i], c=mu[:, i], r=stdev[:, i], ax=axes)

    plt.title('3D GMM')
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    axes.view_init(35.246, 45)
    if export:
        _export_current_figure('3D_GMM_demonstration.png')
    plt.show()


def plot_sphere(w=0, c=(0, 0, 0), r=(1, 1, 1), subdev=10, ax=None, sigma_multiplier=3):
    """
    Plot a sphere (or axis-aligned ellipsoid) surface.

    Input:
        w: scalar used to pick the surface color from the 'jet' colormap; the
           color scale is fixed to [0, 1]
        c: 3 elements sequence, sphere center
        r: 3 elements sequence, sphere original scale in each axis (allowing to draw ellipsoids)
        subdev: scalar, number of subdivisions (subdev^2 points sampled on the surface)
        ax: optional pyplot axis object to plot the sphere in.
        sigma_multiplier: sphere additional scale (choosing an std value when plotting gaussians)
    Output:
        ax: pyplot axis object
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

    # A complex step makes mgrid return `subdev` evenly spaced samples with
    # both end points included.
    phi, theta = np.mgrid[0.0:np.pi:complex(0, subdev), 0.0:2.0 * np.pi:complex(0, subdev)]
    x = sigma_multiplier * r[0] * np.sin(phi) * np.cos(theta) + c[0]
    y = sigma_multiplier * r[1] * np.sin(phi) * np.sin(theta) + c[1]
    z = sigma_multiplier * r[2] * np.cos(phi) + c[2]

    mappable = cmx.ScalarMappable()
    mappable.set_cmap('jet')
    # Fix the color limits: a fresh mappable would otherwise autoscale to the
    # single value `w` (vmin == vmax) and map every value to the same color.
    mappable.set_clim(0.0, 1.0)
    surface_color = mappable.to_rgba(w)

    ax.plot_surface(x, y, z, color=surface_color, alpha=0.2, linewidth=1)

    return ax


def visualize_2D_gmm(points, w, mu, stdev, export=True):
    """
    Plot points and their corresponding GMM model in 2D.

    The points are scattered in one color per Gaussian and every Gaussian is
    drawn as 8 concentric ellipses that get fainter as they grow.
    Input:
        points: n_points X 2, sampled points. They are split by position into
                n_gaussians contiguous groups, and group i gets color i, so
                they are expected to be ordered by Gaussian.
        w: n_gaussians, gmm weights (not used by this plot)
        mu: 2 X n_gaussians, gmm means
        stdev: 2 X n_gaussians, gmm standard deviation (assuming diagonal covariance matrix)
        export: when True, the figure is also saved to
                'images/2D_GMM_demonstration.png'
    Output:
        None
    """
    n_gaussians = mu.shape[1]

    # Visualize data
    plt.figure(figsize=(8, 8))
    axes = plt.gca()
    axes.set_xlim([-1, 1])
    axes.set_ylim([-1, 1])
    plt.set_cmap('Set1')
    colors = cmx.Set1(np.linspace(0, 1, n_gaussians))
    groups = _component_index_groups(points.shape[0], n_gaussians)
    for i, idx in enumerate(groups):
        # `color=` (not `c=`): a single RGBA array passed as `c` is ambiguous
        # and is read as per-point values when the group holds 4 points.
        plt.scatter(points[idx, 0], points[idx, 1], alpha=0.3, color=colors[i])
        for j in range(8):
            # Ellipse j has full width/height (j + 1) * stdev; its alpha
            # decreases with j so the outer rings fade out.
            axes.add_patch(
                patches.Ellipse(mu[:, i], width=(j + 1) * stdev[0, i], height=(j + 1) * stdev[1, i],
                                fill=False, color=[0.0, 0.0, 1.0, 1.0 / (0.5 * j + 1)]))
    plt.title('GMM')
    plt.xlabel('X')
    plt.ylabel('Y')

    if export:
        _export_current_figure('2D_GMM_demonstration.png')

    plt.show()
--------------------------------------------------------------------------------