├── .gitignore
├── GMM from scratch.ipynb
├── README.md
└── images
    ├── animated_GMM new.gif
    ├── eq1.png
    ├── eq2.png
    ├── equatios3.png
    ├── model traininng.png
    ├── output_39_0.png
    ├── output_6_0.png
    ├── output_9_0.png
    └── summary.png

/.gitignore:
--------------------------------------------------------------------------------
# Ignore Python bytecode files
*.pyc

# Ignore Python cache directories
__pycache__/
*.pytest_cache/
.mypy_cache/

# Ignore IPython notebook checkpoints
*.ipynb_checkpoints/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gaussian Mixture Model Clearly Explained
### _The only guide you need to learn everything about GMM_

When we talk about the Gaussian Mixture Model (denoted as GMM throughout this article), it's essential to know how the KMeans algorithm works, because GMM is quite similar to KMeans; you can think of it as a probabilistic version of KMeans. This probabilistic nature allows GMM to be applied to many complex problems that KMeans can't handle.

In summary, KMeans has the following limitations:

1. It assumes that clusters are spherical and equally sized, which is not valid in most real-world scenarios.
2. It's a hard clustering method, meaning each data point is assigned to a single cluster.

Due to these limitations, we should know alternatives to KMeans when working on our machine learning projects. In this article, we will explore one of the best alternatives to KMeans clustering, called the Gaussian Mixture Model.
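To see the hard-versus-soft difference in practice, here is a small scikit-learn comparison (this snippet is my own illustration and is not part of the original notebook): `KMeans` assigns each point to exactly one cluster, while `GaussianMixture` returns a probability for every cluster.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Two overlapping 1-D blobs, reshaped to (n_samples, 1) as scikit-learn expects
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(0, 1, 500), rng.normal(3, 1, 500)]).reshape(-1, 1)

# Hard assignment: each point gets exactly one cluster label
kmeans_labels = KMeans(n_clusters=2, n_init=10).fit_predict(X)
print(kmeans_labels[:3])         # e.g. [0 0 1]

# Soft assignment: each point gets a probability per cluster
gmm = GaussianMixture(n_components=2).fit(X)
print(gmm.predict_proba(X)[:3])  # e.g. [[0.93 0.07], ...]
```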
Throughout this article, we will be covering the following points.

1. How the Gaussian Mixture Model (GMM) algorithm works — in plain English.
2. Mathematics behind GMM.
3. Implement GMM using Python from scratch.

## How Gaussian Mixture Model (GMM) algorithm works — in plain English

As I mentioned earlier, we can call GMM a probabilistic KMeans because the starting point and training process of KMeans and GMM are the same. However, KMeans uses a distance-based approach, while GMM uses a probabilistic approach. There is one primary assumption in GMM: the dataset consists of multiple Gaussians, in other words, a mixture of Gaussians.

```python
# import required libraries
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
from scipy.stats import norm
import warnings
import random

warnings.filterwarnings('ignore')
```

```python
# Generate some data with multiple modes
data1 = np.random.normal(0, 1, 1000)
data2 = np.random.normal(5, 1, 1000)

# Plot the data using seaborn's distplot function
# (distplot is deprecated in recent seaborn releases; sns.histplot(..., kde=True) is the modern equivalent)
sns.distplot(data1, kde=True, hist=True, bins=100, color='b', hist_kws={'alpha': 0.5})
sns.distplot(data2, kde=True, hist=True, bins=100, color='r', hist_kws={'alpha': 0.5})

# Add a legend
plt.legend(['Data 1', 'Data 2'])

# Show the plot
plt.show()
```

![png](images/output_6_0.png)

The above kind of distribution is often called a multimodal distribution. Each peak represents a different Gaussian distribution, or a cluster in our dataset. But the question is,

#### _how do we estimate these distributions?_

Before answering this question, let's create some Gaussian distributions first. Please note that here I am generating multivariate normal distributions; a multivariate normal is a higher-dimensional extension of the univariate normal distribution.

Let's define the mean and covariance of our data points. Using the mean and covariance, we can generate the distributions as follows.

```python
# Set the mean and covariance
mean1 = [0, 0]
mean2 = [2, 0]
cov1 = [[1, .7], [.7, 1]]
cov2 = [[.5, .4], [.4, .5]]

# Generate data from the mean and covariance
data1 = np.random.multivariate_normal(mean1, cov1, size=1000)
data2 = np.random.multivariate_normal(mean2, cov2, size=1000)
```

```python
plt.figure(figsize=(10,6))

plt.scatter(data1[:,0],data1[:,1])
plt.scatter(data2[:,0],data2[:,1])

# use keyword arguments (x=, y=) as required by recent seaborn versions
sns.kdeplot(x=data1[:, 0], y=data1[:, 1], levels=20, linewidths=10, color='k', alpha=0.2)
sns.kdeplot(x=data2[:, 0], y=data2[:, 1], levels=20, linewidths=10, color='k', alpha=0.2)

plt.grid(False)
plt.show()
```

![png](images/output_9_0.png)

As you can see here, we generated random Gaussian distributions using mean and covariance matrices. What about reversing this process? That's exactly what GMM does. But how?

_Because, at the beginning, we don't have any insight into the clusters or their associated mean and covariance matrices._

Well, it happens according to the following steps:

1. Decide the number of clusters for the given dataset (to decide this, we can use domain knowledge or other methods such as BIC/AIC). Assume that we have 1000 data points and we set the number of groups as 2.
2. Initialize the mean, covariance, and weight parameters per cluster (we will explore this further in a later section).
3. Use the Expectation-Maximization algorithm to do the following:
   - Expectation step (E step): calculate the probability of each data point belonging to each cluster, then evaluate the likelihood function using the current estimates of the parameters.
   - Maximization step (M step): update the mean, covariance, and weight parameters to maximize the expected likelihood found in the E step.
   - Repeat these steps until the model converges.

With this information, I am concluding the no-math explanation of the GMM algorithm.

## Mathematics behind GMM

The core of GMM lies within the Expectation-Maximization (EM) algorithm described in the previous section.

Let's demonstrate the EM algorithm in the context of GMM.

__Step 01: Initialize mean, covariance and weight parameters__

1. mean (μ): initialize randomly.
2. covariance (Σ): initialize randomly.
3. weight (mixing coefficient) (π): the fraction of points per class; it refers to the prior probability that a particular data point belongs to each class. In the beginning, this will be equal for all clusters. Assume that we fit a GMM with three components; in this case, the weight parameter would be set to 1/3 for each component, resulting in a probability distribution of (1/3, 1/3, 1/3). A minimal sketch of this initialization is shown right after this list.
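For concreteness, here is a minimal sketch of this initialization for a 1-D dataset and k = 3 components (the variable names and the placeholder dataset are my own; the article's own `random_init` function below does essentially the same thing):

```python
import numpy as np

k = 3                                   # number of components we decided on
X = np.random.normal(0, 1, 300)         # placeholder 1-D dataset to be clustered

pi = np.ones(k) / k                     # equal mixing weights: (1/3, 1/3, 1/3)
means = np.random.choice(X, k)          # pick k random data points as initial means
variances = np.random.random_sample(k)  # random initial variances
```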
__Step 02: Expectation Step (E step)__

For each data point 𝑥_𝑖, calculate the probability that it belongs to cluster 𝑐 using the equation below, where k is the number of distributions we are supposed to find.

![images/eq1.png](images/eq1.png)

Here 𝜋_𝑐 is the mixing coefficient (sometimes called weight) for Gaussian distribution c, which was initialized in the previous stage, and 𝑁(𝒙 | 𝝁,𝚺) describes the probability density function (PDF) of a Gaussian distribution with mean 𝜇 and covariance Σ, evaluated at data point x. We can denote it as below.

![eq2.png](images/eq2.png)

The E-step computes these probabilities using the current estimates of the model's parameters. These probabilities are typically referred to as the "responsibilities" of the Gaussian distributions. They are represented by the variables r_ic, where i is the index of the data point and c is the index of the Gaussian distribution. The responsibility measures how much the c-th Gaussian distribution is responsible for generating the i-th data point. Conditional probability is used here; more specifically, Bayes' theorem.

Let's take a simple example. Assume we have 100 data points and need to cluster them into two groups. We can write r_ic for i = 20 and c = 1 as follows, where i represents the data point's index and c represents the index of the cluster we are considering.

Please note that at the beginning, 𝜋_𝑐 is initialized to be equal for each cluster c = 1, 2, ..., k. In our case, 𝜋_1 = 𝜋_2 = 1/2.

    r_20,1 = π_1 · N(x_20 | μ_1, Σ_1) / ( π_1 · N(x_20 | μ_1, Σ_1) + π_2 · N(x_20 | μ_2, Σ_2) )

The result of the E-step is a set of responsibilities for each data point and each Gaussian distribution in the mixture model. These responsibilities are used in the M-step to update the estimates of the model's parameters.

__Step 03: Maximization Step (M step)__

In this step, the algorithm uses the responsibilities of the Gaussian distributions (computed in the E-step) to update the estimates of the model's parameters.

The M-step updates the estimates of the parameters as follows:

![equatios3.png](images/equatios3.png)

1. Update π_c (the mixing coefficients) using equation 4 above.
2. Update μ_c using equation 5 above.
3. Update Σ_c using equation 6 above.

Additional Fact:

_π_c can be considered equivalent to the fraction of points allocated to cluster 𝑐, because the numerator Σ_𝑖 𝑟_𝑖𝑐 sums the responsibilities of all data points for Gaussian c. If we assume we have 3 clusters and the 𝑖-th data point belongs to cluster 1, we can write its responsibility vector as [0.97, 0.02, 0.01]. If we sum these vectors over all data points, the resulting vector is approximately equal to the number of data points per cluster._

These updated estimates are used in the next E-step to compute new responsibilities for the data points.

This process repeats until the algorithm converges, typically when the model parameters do not change significantly from one iteration to the next.

Lots of ugly and complex equations, right? :)

### _Let's summarize the above facts into one simple diagram_

![summary.png](images/summary.png)

Don't worry; when it comes to coding, each of these equations becomes roughly one line, as the short preview below shows. Let's start to implement GMM from scratch using Python.
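Before the step-by-step walkthrough, here is a compact preview of one EM iteration for a 1-D mixture (this vectorized sketch is my own and uses different variable names from the implementation that follows), with roughly one NumPy line per equation:

```python
import numpy as np
from scipy.stats import norm

def em_iteration(X, means, variances, pi):
    """One EM iteration for a 1-D Gaussian mixture. Shapes: X is (n,), means/variances/pi are (k,)."""
    # N(x_i | mu_c, sigma_c^2) for every component c and data point i -> shape (k, n)
    pdf = np.array([norm.pdf(X, m, np.sqrt(v)) for m, v in zip(means, variances)])
    r = (pi[:, None] * pdf) / np.sum(pi[:, None] * pdf, axis=0)                # E-step responsibilities r_ic
    pi = r.mean(axis=1)                                                        # equation 4: mixing coefficients
    means = (r @ X) / r.sum(axis=1)                                            # equation 5: means
    variances = np.sum(r * (X - means[:, None]) ** 2, axis=1) / r.sum(axis=1)  # equation 6: variances
    return means, variances, pi
```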
## Implement GMM using Python from scratch.

![animated_GMM new.gif](images/animated_GMM%20new.gif)

First things first, let's create a fake dataset. In this section, I will implement GMM for a 1-D dataset.

```python
n_samples = 100
mu1, sigma1 = -5, 1.2
mu2, sigma2 = 5, 1.8
mu3, sigma3 = 0, 1.6

# np.random.normal expects the standard deviation, so sigma1/sigma2/sigma3 are treated as variances here
x1 = np.random.normal(loc = mu1, scale = np.sqrt(sigma1), size = n_samples)
x2 = np.random.normal(loc = mu2, scale = np.sqrt(sigma2), size = n_samples)
x3 = np.random.normal(loc = mu3, scale = np.sqrt(sigma3), size = n_samples)

X = np.concatenate((x1,x2,x3))
```

```python
def plot_pdf(mu,sigma,label,alpha=0.5,linestyle='k--',density=True,color='green'):
    """
    Plot a histogram of samples from a 1-D Gaussian and its PDF curve.

    Parameters
    ----------
    mu : float
        Mean of the distribution.
    sigma : float
        Scale passed to scipy.stats.norm (used directly as the standard deviation).
    label : str
        Legend label for the histogram.
    """
    # Draw samples from the given distribution for the histogram
    X = norm.rvs(mu, sigma, size=1000)

    plt.hist(X, bins=50, density=density, alpha=alpha, label=label, color=color)

    # Plot the PDF
    x = np.linspace(X.min(), X.max(), 1000)
    y = norm.pdf(x, mu, sigma)
    plt.plot(x, y, linestyle)
```

And plot the generated data as follows. Please note that instead of plotting the data itself, I have plotted the probability density of each sample.

```python
plot_pdf(mu1,sigma1,label=r"$\mu={} \ ; \ \sigma={}$".format(mu1,sigma1),color=None)
plot_pdf(mu2,sigma2,label=r"$\mu={} \ ; \ \sigma={}$".format(mu2,sigma2),color=None)
plot_pdf(mu3,sigma3,label=r"$\mu={} \ ; \ \sigma={}$".format(mu3,sigma3),color=None)
plt.title("Original Distribution")
plt.legend()
plt.show()
```

![png](images/output_39_0.png)

Let's build each step described in the previous section.

__Step 01: Initialize mean, covariance, and weights__

```python
def random_init(n_components):
    """Initialize means, weights and variances randomly
    and plot the initialization (the plotting below assumes n_components == 3)."""

    pi = np.ones((n_components)) / n_components
    means = np.random.choice(X, n_components)
    variances = np.random.random_sample(size=n_components)

    plot_pdf(means[0],variances[0],'Random Init 01')
    plot_pdf(means[1],variances[1],'Random Init 02',color='blue')
    plot_pdf(means[2],variances[2],'Random Init 03',color='orange')

    plt.title("Random Initialization")
    plt.legend()
    plt.show()

    return means,variances,pi
```

__Step 02: Expectation Step (E step)__

```python
def step_expectation(X,n_components,means,variances):
    """E Step

    Parameters
    ----------
    X : array-like, shape (n_samples,)
        The data.
    n_components : int
        The number of clusters
    means : array-like, shape (n_components,)
        The means of each mixture component.
    variances : array-like, shape (n_components,)
        The variances of each mixture component.

    Returns
    -------
    weights : array-like, shape (n_components, n_samples)
        The likelihood of each data point under each component.
    """
    weights = np.zeros((n_components,len(X)))
    for j in range(n_components):
        # N(x_i | mu_j, sigma_j^2) for every data point
        weights[j,:] = norm(loc=means[j],scale=np.sqrt(variances[j])).pdf(X)
    return weights
```

With this function, we have covered the first two equations discussed in the E step. Here we build the Gaussian distribution for the current model parameters (means and variances) using scipy's stats module, then use its pdf method to calculate the likelihood of each data point under each cluster.
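As a quick, hypothetical sanity check (assuming `random_init` has already been run on our 300-point dataset `X`), the returned array has one row of likelihoods per component:

```python
means, variances, pi = random_init(3)
likelihoods = step_expectation(X, 3, means, variances)
print(likelihoods.shape)  # expected: (3, 300)
```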
__Step 03: Maximization Step (M step)__

```python
def step_maximization(X,weights,means,variances,n_components,pi):
    """M Step

    Parameters
    ----------
    X : array-like, shape (n_samples,)
        The data.
    weights : array-like, shape (n_components, n_samples)
        The likelihoods computed in the E step.
    means : array-like, shape (n_components,)
        The means of each mixture component.
    variances : array-like, shape (n_components,)
        The variances of each mixture component.
    n_components : int
        The number of clusters
    pi : array-like, shape (n_components,)
        mixture component weights

    Returns
    -------
    variances : array-like, shape (n_components,)
        The updated variances of each mixture component.
    means : array-like, shape (n_components,)
        The updated means of each mixture component.
    pi : array-like, shape (n_components,)
        The updated mixture component weights.
    """
    r = []
    for j in range(n_components):
        # responsibilities of component j for every data point
        r.append((weights[j] * pi[j]) / (np.sum([weights[i] * pi[i] for i in range(n_components)], axis=0)))

        # update the mean, variance and mixing coefficient of component j (equations 5, 6 and 4)
        means[j] = np.sum(r[j] * X) / (np.sum(r[j]))
        variances[j] = np.sum(r[j] * np.square(X - means[j])) / (np.sum(r[j]))
        pi[j] = np.mean(r[j])

    return variances,means,pi
```

```python
def plot_intermediate_steps(means,variances,density=False,save=False,file_name=None):

    plot_pdf(mu1,sigma1,alpha=0.0,linestyle='r--',label='Original Distributions')
    plot_pdf(mu2,sigma2,alpha=0.0,linestyle='r--',label='Original Distributions')
    plot_pdf(mu3,sigma3,alpha=0.0,linestyle='r--',label='Original Distributions')

    color_gen = (x for x in ['green','blue','orange'])

    for mu,sigma in zip(means,variances):
        plot_pdf(mu,sigma,alpha=0.5,label='d',color=next(color_gen))

    if save or file_name is not None:
        # saving a frame requires a file_name and an existing steps/ directory
        step = file_name.split("_")[1]
        plt.title(f"step: {step}")
        plt.savefig(f"steps/{file_name}.png",bbox_inches='tight')
    plt.show()
```

Let's implement the training loop.

```python
def train_gmm(data,n_components=3,n_steps=50, plot_intermediate_steps_flag=True):
    """Training loop of the GMM model

    Parameters
    ----------
    data : array-like, shape (n_samples,)
        The data.
    n_components : int
        The number of clusters
    n_steps : int
        number of iterations to run
    """

    means,variances,pi = random_init(n_components)

    for step in range(n_steps):
        weights = step_expectation(data,n_components,means,variances)
        variances,means,pi = step_maximization(data, weights, means, variances, n_components, pi)

        if plot_intermediate_steps_flag:
            plot_intermediate_steps(means,variances)  # pass file_name=f'step_{step+1}' to save each frame

    plot_intermediate_steps(means,variances)
```
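The loop above always runs for the full `n_steps` iterations. As an optional refinement (my own addition, not part of the original notebook), you could stop early once the total log-likelihood of the data stops improving, which mirrors the `tol` stopping criterion of scikit-learn mentioned below:

```python
def train_gmm_with_tol(data, n_components=3, n_steps=50, tol=1e-4):
    """Same loop as train_gmm, but stops when the log-likelihood gain drops below tol."""
    means, variances, pi = random_init(n_components)
    prev_ll = -np.inf

    for step in range(n_steps):
        weights = step_expectation(data, n_components, means, variances)

        # total log-likelihood of the data under the current parameters
        ll = np.sum(np.log(np.sum(pi[:, None] * weights, axis=0)))
        if np.abs(ll - prev_ll) < tol:
            break
        prev_ll = ll

        variances, means, pi = step_maximization(data, weights, means, variances, n_components, pi)

    return means, variances, pi
```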
When we start the model training, the E and M steps run for the number of iterations given by the n_steps parameter.

In actual use cases, however, you will more often use the scikit-learn version of GMM. There you can find additional parameters, such as:

- tol: the model's stopping criterion. EM iterations will stop when the average gain in the lower bound falls below tol.
- init_params: the method used to initialize the weights, the means, and the precisions.

You may refer to the documentation [here](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html).

Alright, let's see how our handcrafted GMM performs.

```python
train_gmm(X,n_steps=30,plot_intermediate_steps_flag=True)
```

![model traininng.png](images/model%20traininng.png)

In the above diagrams, the red dashed lines represent the original distributions, while the solid curves represent the learned distributions. After the 30th iteration, we can see that our model has performed well on this toy dataset.

--------------------------------------------------------------------------------
/images/animated_GMM new.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/animated_GMM new.gif
--------------------------------------------------------------------------------
/images/eq1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/eq1.png
--------------------------------------------------------------------------------
/images/eq2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/eq2.png
--------------------------------------------------------------------------------
/images/equatios3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/equatios3.png
--------------------------------------------------------------------------------
/images/model traininng.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/model traininng.png
--------------------------------------------------------------------------------
/images/output_39_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/output_39_0.png
--------------------------------------------------------------------------------
/images/output_6_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/output_6_0.png
--------------------------------------------------------------------------------
/images/output_9_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/output_9_0.png
--------------------------------------------------------------------------------
/images/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ransaka/GMM-from-scratch/eeffd90ae0614d3d9bc15c1691c4ead54a6c3ef5/images/summary.png
--------------------------------------------------------------------------------