├── GMMModel.py ├── GMMclustering.py ├── LICENSE ├── PyGMM.py └── README.md /GMMModel.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2014 Flytxt 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | import numpy as np 21 | from pyspark import SparkContext 22 | from GMMclustering import GMMclustering 23 | from pyspark.mllib.linalg import Vectors, SparseVector 24 | 25 | 26 | class GMMModel(object): 27 | """ 28 | A clustering model derived from the Gaussian Mixture model. 
29 | 30 | >>> data = sc.parallelize(np.array([0.5,1,0.75,1,-0.75,0.5,-0.5,0.5,\ 31 | -1,-0.5,-0.75,-0.75,0.75,-0.5,0.75,-0.75]).reshape(8,2)) 32 | >>> model = GMMModel.trainGMM(data,4,10) 33 | >>> np.argmax(model.predict(np.array([0.5,1]))) == \ 34 | np.argmax(model.predict(np.array([0.75,1]))) 35 | True 36 | >>> np.argmax(model.predict(np.array([-0.75,0.5]))) == \ 37 | np.argmax(model.predict(np.array([-0.5,0.5]))) 38 | True 39 | >>> np.argmax(model.predict(np.array([-1,-0.5]))) == \ 40 | np.argmax(model.predict(np.array([0.75,-0.5]))) 41 | False 42 | >>> np.argmax(model.predict(np.array([0.75,-0.75]))) == \ 43 | np.argmax(model.predict(np.array([-0.75,-0.75]))) 44 | False 45 | 46 | >>> sparse_data = ([Vectors.sparse(3, {1: 1.0}),\ 47 | Vectors.sparse(3, {1: 1.1}),\ 48 | Vectors.sparse(3, {2: 1.0}),\ 49 | Vectors.sparse(3, {2: 1.1})]) 50 | >>> sparse_data_rdd = sc.parallelize(sparse_data) 51 | >>> model = GMMModel.trainGMM(sparse_data_rdd,2,10) 52 | >>> np.argmax(model.predict(np.array([0., 1., 0.]))) == \ 53 | np.argmax(model.predict(np.array([0, 1.1, 0.]))) 54 | True 55 | >>> np.argmax(model.predict(Vectors.sparse(3, {1: 1.0}))) == \ 56 | np.argmax(model.predict(Vectors.sparse(3, {2: 1.0}))) 57 | False 58 | >>> np.argmax(model.predict(sparse_data[2])) == \ 59 | np.argmax(model.predict(sparse_data[3])) 60 | True 61 | """ 62 | 63 | @classmethod 64 | def trainGMM(cls, data, n_components, n_iter=100, ct=1e-3): 65 | """ 66 | Train a GMM clustering model. 67 | """ 68 | gmmObj = GMMclustering().fit(data, n_components, n_iter, ct) 69 | return gmmObj 70 | 71 | @classmethod 72 | def resultPredict(cls, gmmObj, data): 73 | """ 74 | Get the result of predict 75 | Return responsibility matrix and cluster labels . 
76 | """ 77 | responsibility_matrix = data.map(lambda m: gmmObj.predict(m)) 78 | cluster_labels = responsibility_matrix.map(lambda b: np.argmax(b)) 79 | return responsibility_matrix, cluster_labels 80 | 81 | 82 | def _test(): 83 | import doctest 84 | globs = globals().copy() 85 | globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) 86 | (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) 87 | globs['sc'].stop() 88 | if failure_count: 89 | exit(-1) 90 | 91 | 92 | if __name__ == "__main__": 93 | _test() 94 | -------------------------------------------------------------------------------- /GMMclustering.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2014 Flytxt 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | # 19 | 20 | import logging 21 | import numpy as np 22 | from operator import add 23 | from scipy.misc import logsumexp 24 | from pyspark.mllib.linalg import Vectors 25 | from pyspark.mllib.clustering import KMeans 26 | 27 | 28 | class GMMclustering: 29 | logging.basicConfig(level=logging.INFO, 30 | format='%(levelname)s %(message)s') 31 | 32 | def fit(self, data, n_components, n_iter, ct): 33 | """ 34 | Estimate model parameters with the expectation-maximization 35 | algorithm. 36 | 37 | Parameters 38 | ---------- 39 | data - RDD of data points 40 | n_components - Number of components 41 | n_iter - Number of iterations. Default to 100 42 | 43 | Attributes 44 | ---------- 45 | 46 | covariance_type : Type of covariance matrix. 47 | Supports only diagonal covariance matrix. 48 | 49 | ct : Threshold value to check the convergence criteria. 50 | Defaults to 1e-3 51 | 52 | min_covar : Floor on the diagonal of the covariance matrix to prevent 53 | overfitting. Defaults to 1e-3. 54 | 55 | converged : True once converged False otherwise. 56 | 57 | Weights : array of shape (1, n_components) 58 | weights for each mixture component. 59 | 60 | Means : array of shape (n_components, n_dim) 61 | Mean parameters for each mixture component. 
62 | 63 | Covars : array of shape (n_components, n_dim) 64 | Covariance parameters for each mixture component 65 | 66 | """ 67 | sc = data.context 68 | covariance_type = 'diag' 69 | converged = False 70 | self.min_covar = 1e-3 71 | 72 | # observation statistics 73 | self.s0 = 0 74 | self.s1 = 0 75 | # To get the no of data points 76 | n_points = data.count() 77 | # To get the no of dimensions 78 | n_dim = data.first().size 79 | 80 | if (n_points == 0): 81 | raise ValueError( 82 | 'Dataset cannot be empty') 83 | if (n_points < n_components): 84 | raise ValueError( 85 | 'Not possible to make (%s) components from (%s) datapoints' % 86 | (n_components, n_points)) 87 | 88 | # Initialize Covars(diagonal covariance matrix) 89 | if hasattr(data.first(), 'indices'): 90 | self.isSparse = 1 91 | 92 | def convert_to_kvPair(eachV): 93 | g = [] 94 | for i in range(eachV.indices.size): 95 | g.append((eachV.indices[i], 96 | (eachV.values[i], eachV.values[i]*eachV.values[i]))) 97 | return g 98 | 99 | def computeVariance(x): 100 | mean = x[1][0]/n_points 101 | sumSq = x[1][1]/n_points 102 | return x[0], sumSq - mean*mean 103 | 104 | cov = [] 105 | kvPair = data.flatMap(convert_to_kvPair) 106 | res = kvPair.reduceByKey(np.add).map(computeVariance) 107 | cov = Vectors.sparse(n_dim, res.collectAsMap()).toArray() + 1e-3 108 | self.Covars = np.tile(cov, (n_components, 1)) 109 | 110 | else: 111 | self.isSparse = 0 112 | cov = [] 113 | for i in range(n_dim): 114 | cov.append(data.map(lambda m: m[i]).variance()+self.min_covar) 115 | self.Covars = np.tile(cov, (n_components, 1)) 116 | 117 | # Initialize Means using MLlib KMeans 118 | self.Means = np.array(KMeans().train(data, n_components).clusterCenters) 119 | # Initialize Weights with the value 1/n_components for each component 120 | self.Weights = np.tile(1.0 / n_components, n_components) 121 | # EM algorithm 122 | # loop until number of iterations or convergence criteria is satisfied 123 | for i in range(n_iter): 124 | 125 | 
logging.info("GMM running iteration %s " % i) 126 | # broadcasting means,covars and weights 127 | self.meansBc = sc.broadcast(self.Means) 128 | self.covarBc = sc.broadcast(self.Covars) 129 | self.weightBc = sc.broadcast(self.Weights) 130 | # Expectation Step 131 | EstepOut = data.map(self.scoreOnePoint) 132 | # Maximization step 133 | MstepIn = EstepOut.reduce(lambda (w1, x1, y1, z1), (w2, x2, y2, z2): 134 | (w1+w2, x1+x2, y1+y2, z1+z2)) 135 | self.s0 = self.s1 136 | self.mStep(MstepIn[0], MstepIn[1], MstepIn[2], MstepIn[3]) 137 | 138 | # Check for convergence. 139 | if i > 0 and abs(self.s1-self.s0) < ct: 140 | converged = True 141 | logging.info("Converged at iteration %s" % i) 142 | break 143 | 144 | return self 145 | 146 | def scoreOnePoint(self, x): 147 | 148 | """ 149 | Compute the log likelihood of 'x' being generated under the current model 150 | Also returns the probability that 'x' is generated by each component of the mixture 151 | 152 | Parameters 153 | ---------- 154 | x : array of shape (1, n_dim) 155 | Corresponds to a single data point. 
156 | 157 | Returns 158 | ------- 159 | log_likelihood_x :Log likelihood of 'x' 160 | prob_x : Resposibility of each cluster for the data point 'x' 161 | 162 | """ 163 | lpr = (self.log_multivariate_normal_density_diag_Nd(x) + np.log(self.Weights)) 164 | log_likelihood_x = logsumexp(lpr) 165 | prob_x = np.exp(lpr-log_likelihood_x) 166 | 167 | if self.isSparse == 1: 168 | temp_wt = np.dot(prob_x[:, np.newaxis], x.toArray()[np.newaxis, :]) 169 | sqVec = Vectors.sparse(x.size, x.indices, x.values**2) 170 | temp_avg = np.dot(prob_x.T[:, np.newaxis], sqVec.toArray()[np.newaxis, :]) 171 | 172 | else: 173 | temp_wt = np.dot(prob_x.T[:, np.newaxis], x[np.newaxis, :]) 174 | temp_avg = np.dot(prob_x.T[:, np.newaxis], (x*x)[np.newaxis, :]) 175 | 176 | return log_likelihood_x, prob_x, temp_wt, temp_avg 177 | 178 | def log_multivariate_normal_density_diag_Nd(self, x): 179 | """ 180 | Compute Gaussian log-density at x for a diagonal model 181 | 182 | """ 183 | 184 | n_features = x.size 185 | 186 | if self.isSparse == 1: 187 | t = Vectors.sparse(x.size, x.indices, x.values**2).dot((1/self.covarBc.value).T) 188 | 189 | else: 190 | t = np.dot(x**2, (1/self.covarBc.value).T) 191 | 192 | lpr = -0.5 * (n_features*np.log(2*np.pi) + np.sum(np.log(self.covarBc.value), 1) + 193 | np.sum((self.meansBc.value ** 2) / self.covarBc.value, 1) 194 | - 2 * x.dot((self.meansBc.value/self.covarBc.value).T) + t) 195 | 196 | return lpr 197 | 198 | def mStep(self, log_sum, prob_sum, weighted_X_sum, weighted_X2_sum): 199 | """ 200 | Perform the Mstep of the EM algorithm. 
201 | Updates Means, Covars and Weights using observation statistics 202 | """ 203 | self.s1 = log_sum 204 | inverse_prob_sum = 1.0 / (prob_sum) 205 | self.Weights = (prob_sum / (prob_sum.sum())) 206 | self.Means = (weighted_X_sum * inverse_prob_sum.T[:, np.newaxis]) 207 | self.Covars = ((weighted_X2_sum * inverse_prob_sum.T[:, np.newaxis]) - (self.Means**2) 208 | + self.min_covar) 209 | 210 | def predict(self, x): 211 | """ 212 | Predicts the cluster to which the given instance belongs to 213 | based on the maximum resposibility. 214 | 215 | Parameters 216 | ---------- 217 | x : array of shape (1, n_dim) 218 | Corresponds to a single data point. 219 | 220 | Returns 221 | ------- 222 | resposibility_matrix:membership values of x in each cluster component 223 | """ 224 | if hasattr(x, 'indices'): 225 | self.isSparse = 1 226 | 227 | else: 228 | self.isSparse = 0 229 | 230 | lpr = (self.log_multivariate_normal_density_diag_Nd(x) + np.log(self.Weights)) 231 | log_likelihood_x = logsumexp(lpr) 232 | prob_x = np.exp(lpr-log_likelihood_x) 233 | resposibility_matrix = np.array(prob_x) 234 | return resposibility_matrix 235 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /PyGMM.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2014 Flytxt 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | # 19 | 20 | """ 21 | Gaussian Mixture Model 22 | This implementation of GMM in pyspark uses the Expectation-Maximization algorithm 23 | to estimate the parameters. 24 | """ 25 | import sys 26 | import argparse 27 | import numpy as np 28 | from GMMModel import GMMModel 29 | from pyspark import SparkContext, SparkConf 30 | 31 | 32 | def parseVector(line): 33 | return np.array([float(x) for x in line.split(',')]) 34 | 35 | if __name__ == "__main__": 36 | """ 37 | Parameters 38 | ---------- 39 | input_file : path of the file which contains the comma separated integer data points 40 | n_components : Number of mixture components 41 | n_iter : Number of EM iterations to perform. Default to 100 42 | ct : convergence_threshold.Default to 1e-3 43 | """ 44 | conf = SparkConf().setAppName("GMM") 45 | sc = SparkContext(conf=conf) 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('input_file', help='input file') 49 | parser.add_argument('n_components', type=int, help='num_of_clusters') 50 | parser.add_argument('--n_iter', default=100, type=int, help='num_of_iterations') 51 | parser.add_argument('--ct', type=float, default=1e-3, help='convergence_threshold') 52 | args = parser.parse_args() 53 | 54 | input_file = args.input_file 55 | lines = sc.textFile(input_file) 56 | data = lines.map(parseVector).cache() 57 | 58 | model = GMMModel.trainGMM(data, args.n_components, args.n_iter, args.ct) 59 | responsibility_matrix, cluster_labels = GMMModel.resultPredict(model, data) 60 | 61 | # Writing the GMM components to files 62 | means_file = input_file.split(".")[0]+"/means" 63 | sc.parallelize(model.Means, 1).saveAsTextFile(means_file) 64 | 65 | covar_file = input_file.split(".")[0]+"/covars" 66 | sc.parallelize(model.Covars, 1).saveAsTextFile(covar_file) 67 | 68 | responsbilities = input_file.split(".")[0]+"/responsbilities" 69 | responsibility_matrix.coalesce(1).saveAsTextFile(responsbilities) 70 | 71 | cluster_file = input_file.split(".")[0]+"/clusters" 72 | 
cluster_labels.coalesce(1).saveAsTextFile(cluster_file) 73 | sc.stop() 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | GMM 2 | === 3 | 4 | Gaussian Mixture Model Implementation in Pyspark 5 | 6 | GMM algorithm models the entire data set as a finite mixture of Gaussian distributions,each parameterized by a mean vector, a covariance matrix and a mixture weights. Here the probability of each point to belong to each cluster is computed along with the cluster statistics. 7 | 8 | This distributed implementation of GMM in pyspark estimates the parameters using the Expectation-Maximization algorithm and considers only diagonal covariance matrix for each component. 9 | 10 | How to Run 11 | ========== 12 | There are two ways to run this code. 13 | 14 | 1. Using the library in your Python program. 15 | 16 | You can train the GMM model by invoking the function GMMModel.trainGMM(data,k,n_iter,ct) where 17 | 18 | data is an RDD(of dense or Sparse Vector), 19 | k is the number of components/clusters, 20 | n_iter is the number of iterations(default 100), 21 | ct is the convergence threshold(default 1e-3). 22 | 23 | To use this library in your program simply download the GMMModel.py and GMMClustering.py 24 | and add them as Python files along with your own user code as shown below: 25 | ``` 26 | wget https://raw.githubusercontent.com/FlytxtRnD/GMM/master/GMMModel.py 27 | wget https://raw.githubusercontent.com/FlytxtRnD/GMM/master/GMMClustering.py 28 | 29 | ./bin/spark-submit --master --py-files GMMModel.py,GMMclustering.py 30 | 31 | [--n_iter ] [--ct ] 32 | ``` 33 | The returned object "model" has the following attributes **model.Means,model.Covars,model.Weights**. 34 | To get the cluster labels and responsibilty matrix(membership values): 35 | 36 | responsibility_matrix,cluster_labels = GMMModel.resultPredict(model, data) 37 | 38 | 2. 
Running the example GMM clustering script.

If you'd like to run our example program directly, also download the PyGMM.py file and invoke it
with spark-submit.
```
wget https://raw.githubusercontent.com/FlytxtRnD/GMM/master/PyGMM.py
./bin/spark-submit --master <master> --py-files GMMModel.py,GMMclustering.py \
    PyGMM.py <input_file> <n_components> \
    [--n_iter <n_iter>] [--ct <ct>]
```
where `<master>` is your Spark master URL and `<input_file>` is a text file with one data point per line, written as comma-separated numeric values.
Make sure you enter the full path to the downloaded files.

--------------------------------------------------------------------------------