├── .gitignore ├── LICENSE ├── README.md ├── gmm ├── em.cpp ├── em.h ├── gaussian.h ├── gaussian.inl ├── gmm.cpp ├── gmm.h ├── gmm.vcxproj ├── gmm.vcxproj.filters ├── kmeans.h ├── kmeans.inl ├── math_utils.h ├── random_generator.h └── random_generator.inl ├── gmm_example ├── gmm_example.cpp ├── gmm_example.vcxproj ├── gmm_example.vcxproj.filters ├── stdafx.cpp ├── stdafx.h └── targetver.h └── painless_gmm.sln /.gitignore: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # This .gitignore file was automatically created by Microsoft(R) Visual Studio. 3 | ################################################################################ 4 | 5 | # We put our external components in $(SolutionDir)\External, they will not be added to the repo 6 | External/ 7 | 8 | ## Ignore Visual Studio temporary files, build results, and 9 | ## files generated by popular Visual Studio add-ons. 
10 | 11 | # User-specific files 12 | *.suo 13 | *.user 14 | *.userosscache 15 | *.sln.docstates 16 | 17 | # User-specific files (MonoDevelop/Xamarin Studio) 18 | *.userprefs 19 | 20 | # Build results 21 | [Dd]ebug/ 22 | [Dd]ebugPublic/ 23 | [Rr]elease/ 24 | [Rr]eleases/ 25 | [Xx]64/ 26 | [Xx]86/ 27 | [Bb]uild/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | 32 | # Visual Studio 2015 cache/options directory 33 | .vs/ 34 | # Uncomment if you have tasks that create the project's static files in wwwroot 35 | #wwwroot/ 36 | 37 | # MSTest test Results 38 | [Tt]est[Rr]esult*/ 39 | [Bb]uild[Ll]og.* 40 | 41 | # NUNIT 42 | *.VisualState.xml 43 | TestResult.xml 44 | 45 | # Build Results of an ATL Project 46 | [Dd]ebugPS/ 47 | [Rr]eleasePS/ 48 | dlldata.c 49 | 50 | # DNX 51 | project.lock.json 52 | artifacts/ 53 | 54 | *_i.c 55 | *_p.c 56 | *_i.h 57 | *.ilk 58 | *.meta 59 | *.obj 60 | *.pch 61 | *.pdb 62 | *.pgc 63 | *.pgd 64 | *.rsp 65 | *.sbr 66 | *.tlb 67 | *.tli 68 | *.tlh 69 | *.tmp 70 | *.tmp_proj 71 | *.log 72 | *.vspscc 73 | *.vssscc 74 | .builds 75 | *.pidb 76 | *.svclog 77 | *.scc 78 | 79 | # Chutzpah Test files 80 | _Chutzpah* 81 | 82 | # Visual C++ cache files 83 | ipch/ 84 | *.aps 85 | *.ncb 86 | *.opendb 87 | *.opensdf 88 | *.sdf 89 | *.cachefile 90 | *.VC.db 91 | 92 | # Visual Studio profiler 93 | *.psess 94 | *.vsp 95 | *.vspx 96 | *.sap 97 | 98 | # TFS 2012 Local Workspace 99 | $tf/ 100 | 101 | # Guidance Automation Toolkit 102 | *.gpState 103 | 104 | # ReSharper is a .NET coding add-in 105 | _ReSharper*/ 106 | *.[Rr]e[Ss]harper 107 | *.DotSettings.user 108 | 109 | # JustCode is a .NET coding add-in 110 | .JustCode 111 | 112 | # TeamCity is a build add-in 113 | _TeamCity* 114 | 115 | # DotCover is a Code Coverage Tool 116 | *.dotCover 117 | 118 | # NCrunch 119 | _NCrunch_* 120 | .*crunch*.local.xml 121 | nCrunchTemp_* 122 | 123 | # MightyMoose 124 | *.mm.* 125 | AutoTest.Net/ 126 | 127 | # Web workbench (sass) 128 | .sass-cache/ 129 | 130 | # Installshield output 
folder 131 | [Ee]xpress/ 132 | 133 | # DocProject is a documentation generator add-in 134 | DocProject/buildhelp/ 135 | DocProject/Help/*.HxT 136 | DocProject/Help/*.HxC 137 | DocProject/Help/*.hhc 138 | DocProject/Help/*.hhk 139 | DocProject/Help/*.hhp 140 | DocProject/Help/Html2 141 | DocProject/Help/html 142 | 143 | # Click-Once directory 144 | publish/ 145 | 146 | # Publish Web Output 147 | *.[Pp]ublish.xml 148 | *.azurePubxml 149 | 150 | # TODO: Un-comment the next line if you do not want to checkin 151 | # your web deploy settings because they may include unencrypted 152 | # passwords 153 | #*.pubxml 154 | *.publishproj 155 | 156 | # NuGet Packages 157 | *.nupkg 158 | # The packages folder can be ignored because of Package Restore 159 | **/packages/* 160 | # except build/, which is used as an MSBuild target. 161 | !**/packages/build/ 162 | # Uncomment if necessary however generally it will be regenerated when needed 163 | #!**/packages/repositories.config 164 | # NuGet v3's project.json files produces more ignoreable files 165 | *.nuget.props 166 | *.nuget.targets 167 | 168 | # Microsoft Azure Build Output 169 | csx/ 170 | *.build.csdef 171 | 172 | # Microsoft Azure Emulator 173 | ecf/ 174 | rcf/ 175 | 176 | # Microsoft Azure ApplicationInsights config file 177 | ApplicationInsights.config 178 | 179 | # Windows Store app package directory 180 | AppPackages/ 181 | BundleArtifacts/ 182 | 183 | # Visual Studio cache files 184 | # files ending in .cache can be ignored 185 | *.[Cc]ache 186 | # but keep track of directories ending in .cache 187 | !*.[Cc]ache/ 188 | 189 | # Others 190 | ClientBin/ 191 | [Ss]tyle[Cc]op.* 192 | ~$* 193 | *~ 194 | *.dbmdl 195 | *.dbproj.schemaview 196 | *.pfx 197 | *.publishsettings 198 | node_modules/ 199 | orleans.codegen.cs 200 | 201 | # RIA/Silverlight projects 202 | Generated_Code/ 203 | 204 | # Backup & report files from converting an old project file 205 | # to a newer Visual Studio version. 
Backup files are not needed, 206 | # because we have git ;-) 207 | _UpgradeReport_Files/ 208 | Backup*/ 209 | UpgradeLog*.XML 210 | UpgradeLog*.htm 211 | 212 | # SQL Server files 213 | *.mdf 214 | *.ldf 215 | 216 | # Business Intelligence projects 217 | *.rdl.data 218 | *.bim.layout 219 | *.bim_*.settings 220 | 221 | # Microsoft Fakes 222 | FakesAssemblies/ 223 | 224 | # GhostDoc plugin setting file 225 | *.GhostDoc.xml 226 | 227 | # Node.js Tools for Visual Studio 228 | .ntvs_analysis.dat 229 | 230 | # Visual Studio 6 build log 231 | *.plg 232 | 233 | # Visual Studio 6 workspace options file 234 | *.opt 235 | 236 | # Visual Studio LightSwitch build output 237 | **/*.HTMLClient/GeneratedArtifacts 238 | **/*.DesktopClient/GeneratedArtifacts 239 | **/*.DesktopClient/ModelManifest.xml 240 | **/*.Server/GeneratedArtifacts 241 | **/*.Server/ModelManifest.xml 242 | _Pvt_Extensions 243 | 244 | # LightSwitch generated files 245 | GeneratedArtifacts/ 246 | ModelManifest.xml 247 | 248 | # Paket dependency manager 249 | .paket/paket.exe 250 | 251 | # FAKE - F# Make 252 | .fake/ 253 | *.licenseheader 254 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Alvaro Collet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tutorial on Painless Gaussian Mixture Models 2 | A real-world implementation of a Gaussian Mixture Model in C++, without the pain. 3 | 4 | ## Introduction to Painless GMM 5 | 6 | A Gaussian Mixture Model (GMM) is a probability distribution defined as a linear combination of weighted Gaussian distributions. It is commonly used in computer vision and image processing tasks, such as estimating a color distribution for foreground/background segmentation, or in clustering problems. This project is intended as an **educational** tool on how to properly implement a Gaussian Mixture Model. 7 | 8 | GMMs are annoying to implement. The math behind GMMs is very easy to understand, but it is not possible to take the formulas and implement them directly. A straight implementation of the GMM formulas leads to underflow errors, singular matrices, divisions-by-zero, and NaNs. The likelihoods involved in GMM are very frequently too small to be directly represented as floating-point numbers (and, even more so, their multiplication). In the following paragraphs and code, I show the changes needed 9 | to take GMM from theory to a robust real-world implementation. Therefore, this is an implementation of GMM without the pain: a Painless GMM. 
10 | 11 | ## GMM: The theory 12 | 13 | A GMM is a probability distribution defined as a linear combination of ![equation](https://latex.codecogs.com/gif.latex?K) weighted Gaussian distributions, 14 | 15 |             ![equation](https://latex.codecogs.com/gif.latex?P_%7BGMM%7D%28z_i%20%7C%20%5Cvi%7B%5Cpi%7D%2C%20%5Cvi%7B%5Cmu%7D%2C%20%5Cvi%7B%5CSigma%7D%29%20%3D%20%5Csum_k%20%5Cpi_k%20%5Cmathcal%7BN%7D%28z_i%20%7C%20%5Cmu_k%2C%20%5CSigma_k%29%2C) 16 | 17 | with weights ![equation](https://latex.codecogs.com/gif.latex?%24%5Cpi_k%24), means ![equation](https://latex.codecogs.com/gif.latex?%24%5Cmu_k%24) and covariance matrices ![equation](https://latex.codecogs.com/gif.latex?%24%5CSigma_k%24). In the following sections, we simplify this notation as ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_i%20%7C%20k%29%20%3D%20%5Cmathcal%7BN%7D%28z_i%20%7C%20%5Cmu_k%2C%20%5CSigma_k%29%24) and ![equation](https://latex.codecogs.com/gif.latex?%24P%28k%29%20%3D%20%5Cpi_k%24). The GMM likelihood then becomes ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_i%29%20%3D%20%5Csum_k%20p%28z_i%7Ck%29%20P%28k%29%24). 18 | 19 | For more information about GMMs, see Reynolds' [GMM tutorial](http://www.ee.iisc.ernet.in/new/people/faculty/prasantg/downloads/GMM_Tutorial_Reynolds.pdf) or the [Wikipedia page](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model). 
20 | 21 | ## Training a GMM with Expectation-Maximization (EM) 22 | We start with a data set ![equation](https://latex.codecogs.com/gif.latex?%24%5Cvi%7Bz%7D%24) of ![equation](https://latex.codecogs.com/gif.latex?N) ![equation](https://latex.codecogs.com/gif.latex?d)-dimensional feature vectors ![equation](https://latex.codecogs.com/gif.latex?z_i) (e.g., ![equation](https://latex.codecogs.com/gif.latex?d%3D3) for RGB color pixels), an initial set of ![equation](https://latex.codecogs.com/gif.latex?K) Gaussian distributions (initialized as described below), and ![equation](https://latex.codecogs.com/gif.latex?K) weights ![equation](https://latex.codecogs.com/gif.latex?%24P%28k%29%24). We use the Expectation-Maximization (EM) algorithm to find the Gaussian distributions and weights that maximize the global GMM likelihood ![equation](https://latex.codecogs.com/gif.latex?%24p%28%5Cvi%7Bz%7D%29%20%3D%20%5Cprod_i%20p%28z_i%29%24), that is, the mixture of Gaussian distributions and weights that best fits the data set ![equation](https://latex.codecogs.com/gif.latex?%24%5Cvi%7Bz%7D%24). 23 | 24 | The EM algorithm is an optimization algorithm that maximizes ![equation](https://latex.codecogs.com/gif.latex?%24p%28%5Cvi%7Bz%7D%29%24) by coordinate ascent, alternating between expectation steps (E-steps) and maximization steps (M-steps). The algorithm starts with an initial E-step. 25 | 26 | In the **E-step**, we determine the responsibility ![equation](https://latex.codecogs.com/gif.latex?p%28k%7Cz_i%29) of each Gaussian distribution for each training data point ![equation](https://latex.codecogs.com/gif.latex?z_i), as 27 | 28 |             ![equation](https://latex.codecogs.com/gif.latex?%24%24p_%7Bki%7D%20%5Ctriangleq%20p%28k%7Cz_i%29%20%3D%20%5Cfrac%7Bp%28z_i%7Ck%29P%28k%29%7D%7Bp%28z_i%29%7D%2C%24%24) 29 | 30 | that is, we estimate how likely it is that each Gaussian distribution generated the data point ![equation](https://latex.codecogs.com/gif.latex?z_i). 
31 | 32 | In the **M-step**, we re-estimate the Gaussian distributions and weights given the responsibilities ![equation](https://latex.codecogs.com/gif.latex?p_%7Bki%7D). In particular, we update ![equation](https://latex.codecogs.com/gif.latex?P%28k%29), ![equation](https://latex.codecogs.com/gif.latex?%24%5Cmu_k%24) and ![equation](https://latex.codecogs.com/gif.latex?%24%5CSigma_k%24) as 33 | 34 |            ![equation](https://latex.codecogs.com/gif.latex?%24%24%20P%28k%29%20%3D%20%5Cfrac%7B1%7D%7BN%7D%20%5Csum_i%20p_%7Bki%7D%20%2C%24%24)       ![equation](https://latex.codecogs.com/gif.latex?%24%24%20%5Cmu_k%20%3D%20%5Cfrac%7B%5Csum_i%20p_%7Bki%7Dz_i%7D%7B%5Csum_i%20p_%7Bki%7D%7D%20%2C%24%24)       ![equation](https://latex.codecogs.com/gif.latex?%5CSigma_k%20%3D%20%5Cfrac%7B%5Csum_i%20p_%7Bki%7D%28z_i%20-%20%5Cmu_k%29%28z_i%20-%20%5Cmu_k%29%5ET%7D%7B%5Csum_i%20p_%7Bki%7D%7D.) 35 | 36 | Note that ![equation](https://latex.codecogs.com/gif.latex?z_i) and ![equation](https://latex.codecogs.com/gif.latex?%24%5Cmu_k%24) are considered column vectors, so that the outer product ![equation](https://latex.codecogs.com/gif.latex?%24%28z_i%20-%20%5Cmu_k%29%28z_i%20-%20%5Cmu_k%29%5ET%24) results in a ![equation](https://latex.codecogs.com/gif.latex?%24d%5Ctimes%20d%24) matrix (![equation](https://latex.codecogs.com/gif.latex?%243%5Ctimes3%24) for RGB pixels). 37 | 38 | We then alternate between the E-step and the M-step until ![equation](https://latex.codecogs.com/gif.latex?%24p%28%5Cvi%7Bz%7D%29%24) no longer increases significantly. For example, a common stopping criterion is when the log likelihood ![equation](https://latex.codecogs.com/gif.latex?%24%5Clog%20p%28%5Cvi%7Bz%7D%29%24) changes by less than 0.01% between iterations, or after 100 iterations. 39 | 40 | ## Avoiding underflow errors 41 | The procedure described above for GMM training is correct, but it cannot be implemented directly. A straight implementation of the previous formulas leads to underflow errors and singular matrices, which we must avoid in a robust implementation. 
42 | 43 | The likelihoods and responsibilities involved in GMM are very frequently too small to be directly represented as floating-point numbers (and, even more so, their products). An effective solution to the underflow problem is to use log likelihoods and the `logsumexp` trick. 44 | 45 | First, we must use the Gaussian log likelihood ![equation](https://latex.codecogs.com/gif.latex?%24l%28z_i%7Ck%29%20%3D%20%5Clog%20p%28z_i%20%7C%20k%29%24) instead of 46 | the linear likelihood, as 47 | 48 |            ![equation](https://latex.codecogs.com/gif.latex?%24%24%20l%28z_i%7Ck%29%20%3D%20%5Clog%20p%28z_i%7Ck%29%20%3D%20-%5Cfrac%7B1%7D%7B2%7D%20%5Clog%20%28%20%282%5Cpi%29%5Ed%20%7C%5CSigma_k%7C%29%20-%20%5Cfrac%7B1%7D%7B2%7D%20%28z_i%20-%20%5Cmu_k%29%5ET%20%5CSigma_k%5E%7B-1%7D%20%28z_i%20-%20%5Cmu_k%29.%20%24%24) 49 | 50 | We must also use the log likelihood of the whole GMM instead of the linear likelihood, but this calculation is slightly more convoluted. Given that the formula ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_i%29%20%3D%20%5Csum_k%20p%28z_i%7Ck%29%20P%28k%29%24) adds likelihoods, using ![equation](https://latex.codecogs.com/gif.latex?%24%5Clog%20p%28z_i%29%24) does not offer any immediate advantage (because we cannot directly add log likelihoods). Instead, we use the `logsumexp` trick `LSE()`, which 51 | states that 52 | 53 |            ![equation](https://latex.codecogs.com/gif.latex?LSE%28z_i%29%20%5Ctriangleq%20%5Clog%20%5Cleft%28%20%5Csum_i%20z_i%20%5Cright%29%20%3D%20%5Clog%20z_%5Ctext%7Bmax%7D%20+%20%5Clog%20%5Cleft%28%20%5Csum_i%20%5Cexp%20%5Cleft%28%20%5Clog%20z_i%20-%20%5Clog%20z_%5Ctext%7Bmax%7D%20%5Cright%29%20%5Cright%29.) 
54 | 55 | In `LSE()`, we scale the ![equation](https://latex.codecogs.com/gif.latex?N) terms in the summation by the largest term ![equation](https://latex.codecogs.com/gif.latex?z_%5Ctext%7Bmax%7D) and only convert the scaled terms to the linear domain. 56 | 57 | Let us give an example of the `logsumexp` trick at work. Let ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_1%29%20%3D%20%5Cexp%28-1000%29%24) and ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_2%29%20%3D%20%5Cexp%28-1001%29%24). We wish to compute ![equation](https://latex.codecogs.com/gif.latex?%24X%20%3D%20%5Clog%28p%28z_1%29+p%28z_2%29%29%24). The direct evaluation of ![equation](https://latex.codecogs.com/gif.latex?%24X%20%3D%20%5Clog%28p%28z_1%29+p%28z_2%29%29%24) 58 | requires calculating ![equation](https://latex.codecogs.com/gif.latex?%24%5Cexp%28-1000%29%24) and ![equation](https://latex.codecogs.com/gif.latex?%24%5Cexp%28-1001%29%24), which causes underflow (regardless of the representation, `float` or `double`), and therefore ![equation](https://latex.codecogs.com/gif.latex?%24X%20%3D%20%5Clog%20%280%20%2B%200%29%20%3D%20%5Clog%200%20%3D%20-%5Cinfty%24). Using `logsumexp`, this is ![equation](https://latex.codecogs.com/gif.latex?%24X%20%3D%20-1000%20+%20%5Clog%20%5Cleft%28%5Cexp%280%29%20+%20%5Cexp%28-1%29%5Cright%29%20%5Capprox%20-999.7%24). 59 | 60 | The GMM log likelihood can be expressed as ![equation](https://latex.codecogs.com/gif.latex?%24l%28z_i%29%20%3D%20%5Clog%20p%28z_i%29%20%3D%20LSE%28p%28z_i%7Ck%29%20P%28k%29%29%24). 
Given that we already calculate ![equation](https://latex.codecogs.com/gif.latex?%24l%28z_i%7Ck%29%24) instead of ![equation](https://latex.codecogs.com/gif.latex?%24p%28z_i%7Ck%29%24), the GMM log likelihood becomes 61 | 62 |            ![equation](https://latex.codecogs.com/gif.latex?l_%5Ctext%7Bmax%7D%28z_i%29%20%3D%20%5Cmax_k%20%28l%28z_i%7Ck%29%20+%20%5Clog%20P%28k%29%29) 63 | 64 |            ![equation](https://latex.codecogs.com/gif.latex?l%28z_i%29%20%3D%20%5Clog%20%5Csum_k%20p%28z_i%7Ck%29%20P%28k%29%20%3D%20l_%5Ctext%7Bmax%7D%28z_i%29%20+%20%5Clog%20%5Csum_k%20%5Cexp%28%20l%28z_i%7Ck%29%20+%20%5Clog%20P%28k%29%20-%20l_%5Ctext%7Bmax%7D%28z_i%29%29.) 65 | 66 | 67 | Analogously, the E-step becomes 68 | 69 |            ![equation](https://latex.codecogs.com/gif.latex?%24%24l_%7Bki%7D%20%5Ctriangleq%20l%28k%7Cz_i%29%20%3D%20%5Clog%20p%28k%7Cz_i%29%20%3D%20l%28z_i%7Ck%29%20+%20%5Clog%20P%28k%29%20-%20l%28z_i%29%24%24) 70 | 71 | and the responsibilities ![equation](https://latex.codecogs.com/gif.latex?p_%7Bki%7D) are computed from ![equation](https://latex.codecogs.com/gif.latex?l_%7Bki%7D), i.e., ![equation](https://latex.codecogs.com/gif.latex?%24p_%7Bki%7D%20%3D%20%5Cexp%28l_%7Bki%7D%29%24). 72 | 73 | The M-step does not require any changes to prevent underflows. Finally, 74 | the global GMM log likelihood becomes 75 | 76 |            ![equation](https://latex.codecogs.com/gif.latex?l%28%5Cvi%7Bz%7D%29%20%3D%20%5Clog%20p%28%5Cvi%7Bz%7D%29%20%3D%20%5Csum_i%20l%28z_i%29) 77 | 78 | 79 | ## Avoiding singular matrix inversions 80 | 81 | The second main problem in a robust GMM implementation is the appearance of singular matrix inversions. This issue commonly arises with low-variance patches. For example, an image with a section of saturated pixels (e.g., when the camera points at a light) contains an area with constant color and zero variance. 
If we attempt to train a GMM on such an image, all pixels in the zero-variance patch will be clustered together, but the evaluation of the Gaussian log likelihood ![equation](https://latex.codecogs.com/gif.latex?%24l%28z_i%7Ck%29%24) will fail because it requires an inverse covariance matrix. A patch with constant color has a singular covariance matrix (all zeros), which is not invertible. 82 | 83 | The simplest solution to this problem is to add bounds to the computation of the estimated covariance matrices. In particular, after 84 | evaluating each covariance matrix, we evaluate its reciprocal condition number 85 | 86 |            ![equation](https://latex.codecogs.com/gif.latex?%5Ctext%7BRCOND%7D%28%5CSigma%29%20%3D%20%5Cfrac%7B1%7D%7B%5C%7C%5CSigma%5C%7C_1%20%5C%7C%5CSigma%5E%7B-1%7D%5C%7C_1%7D.) 87 | 88 | For well-conditioned matrices, `RCOND` is close to 1, whereas it approaches zero for ill-conditioned (close to singular) matrices. In our implementation, we monitor each matrix so that ![equation](https://latex.codecogs.com/gif.latex?%24%5Ctext%7BRCOND%7D%28%5CSigma%29%20%3E%20%5Cepsilon%24), with ![equation](https://latex.codecogs.com/gif.latex?%5Cepsilon%20%3D%2010%5E%7B-10%7D). If this condition is not met, 89 | we force ![equation](https://latex.codecogs.com/gif.latex?%24%5CSigma%24) to be a diagonal matrix with small (but well-conditioned) variance. 90 | 91 | ## Initialization 92 | The training of a GMM requires some initialization for the means and covariances. A common approach is to use K-Means as a starting point. In our case, we implemented a basic K-Means algorithm with Forgy initialization. We use the output cluster centroids and cluster variances to initialize our GMM distribution, with the cluster centroids becoming the GMM means ![equation](https://latex.codecogs.com/gif.latex?%24%5Cmu_k%24) and the cluster variances becoming diagonal covariance matrices ![equation](https://latex.codecogs.com/gif.latex?%24%5CSigma_k%24). 
93 | 94 | ## Data whitening 95 | Clustering algorithms like K-Means and GMM converge more slowly when the data is badly scaled, or when the variances of different features differ greatly. A common solution to this problem is to perform a *data whitening* step prior to clustering. To *whiten* a data set, we rescale each feature (e.g., the R, G, and B channels in an RGB pixel) in the 96 | feature vector ![equation](https://latex.codecogs.com/gif.latex?%24z_i%24) so that it has unit variance. Consider the scaling matrix 97 | 98 |            ![equation](https://latex.codecogs.com/gif.latex?T%20%3D%20%5Cbegin%7Bpmatrix%7D%20%5Csigma_R%20%26%200%20%26%200%5C%5C%200%20%26%20%5Csigma_G%20%26%200%5C%5C%200%20%26%200%20%26%5Csigma_B%20%5Cend%7Bpmatrix%7D.) 99 | 100 | The data whitening of the feature vector ![equation](https://latex.codecogs.com/gif.latex?%24z_i%24) is then 101 | 102 |            ![equation](https://latex.codecogs.com/gif.latex?%5Cbar%7Bz%7D_i%20%3D%20T%5E%7B-1%7Dz_i%20%3D%20%5Cleft%5B%5Cfrac%7BR_i%7D%7B%5Csigma_%7BR%7D%7D%2C%5C%20%5Cfrac%7BG_i%7D%7B%5Csigma_%7BG%7D%7D%2C%5C%20%5Cfrac%7BB_i%7D%7B%5Csigma_%7BB%7D%7D%5Cright%5D.) 103 | 104 | We use the whitened data set ![equation](https://latex.codecogs.com/gif.latex?%5Cvi%7B%5Cbar%7Bz%7D%7D%20%3D%20%5B%5Cbar%7Bz%7D_1%2C%20%5Cldots%2C%20%5Cbar%7Bz%7D_i%2C%20%5Cldots%2C%20%5Cbar%7Bz%7D_N%5D) as an input to K-Means and then to GMM. After the GMM has converged on the whitened data, we rescale the whitened means ![equation](https://latex.codecogs.com/gif.latex?%5Cbar%7B%5Cmu%7D_k) and covariances ![equation](https://latex.codecogs.com/gif.latex?%5Cbar%7B%5CSigma%7D_k) back to the original scale. In particular,
106 | -------------------------------------------------------------------------------- /gmm/em.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | */ 28 | #include "gmm.h" 29 | #include "em.h" 30 | 31 | //---------------------------------------------------------------------------- 32 | AC::GMM::EM::EM(int numObservations, int numModes, double tolerance, int maxIterations) 33 | : m_numTrainingPoints(numObservations) 34 | , m_tolerance(tolerance) 35 | , m_maxIterations(maxIterations) 36 | { 37 | // Reserve memory for our temporary vectors (to avoid allocations while processing) 38 | m_tmpResponsibilities.reserve(numModes); 39 | for (int k = 0; k < numModes; ++k) 40 | m_tmpResponsibilities.emplace_back(std::vector(numObservations)); 41 | } 42 | 43 | //---------------------------------------------------------------------------- 44 | void AC::GMM::EM::UpdateResponsibilities(const std::vector& observations, GMM3D& gmm) 45 | { 46 | _ASSERT(observations.size() == m_numTrainingPoints && m_numTrainingPoints >= gmm.Modes().size() && "Invalid number of observations."); 47 | int numObservations = (int)observations.size(); 48 | int numModes = (int)gmm.Modes().size(); 49 | for (int o = 0; o < numObservations; ++o) { 50 | for (int idxMode = 0; idxMode < numModes; ++idxMode) { 51 | double responsibility = exp(gmm.LogResponsibility(observations[o], idxMode)); 52 | m_tmpResponsibilities[idxMode][o] = IsFinite(responsibility) ? 
responsibility : 0; 53 | } 54 | } 55 | } 56 | 57 | //---------------------------------------------------------------------------- 58 | void AC::GMM::EM::UpdateWeights(GMM3D& gmm) 59 | { 60 | _ASSERT(m_numTrainingPoints >= gmm.Modes().size() && "Invalid number of observations."); 61 | int numModes = (int)gmm.Modes().size(); 62 | for (int k = 0; k < numModes; ++k) { 63 | AC::OnlineMean weight; 64 | for (int o = 0; o < m_numTrainingPoints; ++o) 65 | weight.Push(m_tmpResponsibilities[k][o]); 66 | gmm.Modes(k)->setWeight(std::max(weight.Mean(), c_SafeMinWeight)); // sum_n(p(k|x_n))/N 67 | } 68 | } 69 | 70 | //---------------------------------------------------------------------------- 71 | void AC::GMM::EM::UpdateMeans(const std::vector& observations, GMM3D& gmm) 72 | { 73 | _ASSERT(observations.size() == m_numTrainingPoints && m_numTrainingPoints >= gmm.Modes().size() && "Invalid number of observations."); 74 | int numModes = (int)gmm.Modes().size(); 75 | AC::OnlineMean sumMean; 76 | for (int k = 0; k < numModes; ++k) { 77 | for (int o = 0; o < m_numTrainingPoints; ++o) 78 | sumMean.Push(observations[o] * m_tmpResponsibilities[k][o]); // sum_n(p(k|x_n)*x_n)/N 79 | // In UpdateWeights, we calculate sum_n( p(k|x_n) )/N, so the /N in sumMean cancels this one, and we get: 80 | gmm.Modes(k)->setMean(sumMean.Mean() / std::max(gmm.Modes(k)->Weight(), c_SafeMinWeight)); // sum_n(p(k|x_n)*x_n)/sum_n(p(k|x_n)) 81 | sumMean.Reset(); // Reset for the next use 82 | } 83 | } 84 | 85 | //---------------------------------------------------------------------------- 86 | void AC::GMM::EM::UpdateCovariances(const std::vector& observations, GMM3D& gmm) 87 | { 88 | _ASSERT(observations.size() == m_numTrainingPoints && m_numTrainingPoints >= gmm.Modes().size() && "Invalid number of observations."); 89 | int numModes = (int)gmm.Modes().size(); 90 | Vec3 centeredObservation; 91 | Mat3 cov; 92 | Mat3 centeredObservationOuterProd; 93 | for (int k = 0; k < numModes; ++k) { 94 | cov = 
Mat3::Zero(); 95 | for (int o = 0; o < m_numTrainingPoints; ++o) { 96 | centeredObservation = observations[o] - gmm.Modes(k)->Mean(); 97 | centeredObservationOuterProd = centeredObservation * centeredObservation.transpose(); 98 | cov += m_tmpResponsibilities[k][o] * centeredObservationOuterProd; 99 | } 100 | // In UpdateWeights, we calculate sum_n( p(k|x_n) )/N, and we need sum_n( p(k|x_n) ) in the denominator below 101 | cov /= m_numTrainingPoints * std::max(gmm.Modes(k)->Weight(), c_SafeMinWeight); // so we divide by sum_n(p(k|x_n))/N * N 102 | 103 | gmm.Modes(k)->setCovariance(cov); 104 | } 105 | } 106 | 107 | //---------------------------------------------------------------------------- 108 | bool AC::GMM::EM::Process(const std::vector& observations, GMM3D& gmm) 109 | { 110 | _ASSERT(m_numTrainingPoints > 0 && m_numTrainingPoints > gmm.Modes().size() && "Invalid number of observations."); 111 | int numIterations = 0; 112 | double OldLikelihood; 113 | double NewLikelihood = -std::numeric_limits::max(); 114 | do { 115 | // E-Step 116 | UpdateResponsibilities(observations, gmm); 117 | 118 | // M-Step 119 | UpdateWeights(gmm); 120 | UpdateMeans(observations, gmm); 121 | UpdateCovariances(observations, gmm); 122 | 123 | // Update GMM likelihood (stopping condition) 124 | OldLikelihood = NewLikelihood; 125 | NewLikelihood = gmm.LogLikelihood(observations); 126 | } while (abs((NewLikelihood - OldLikelihood) / OldLikelihood) > m_tolerance && numIterations++ < m_maxIterations); 127 | 128 | // Return true if converged 129 | return numIterations < m_maxIterations; 130 | } 131 | -------------------------------------------------------------------------------- /gmm/em.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the 
"Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #ifndef __EM_H__ 29 | #define __EM_H__ 30 | 31 | #include "math_utils.h" 32 | #include "gaussian.h" 33 | #include "gmm.h" 34 | 35 | namespace AC 36 | { 37 | namespace GMM 38 | { 39 | 40 | // Expectation-Maximization algorithm for GMM 41 | class EM { 42 | 43 | public: 44 | EM(int numObservations, int numModes, double tolerance = c_EMDefaultTolerance, int maxIterations = c_EMDefaultMaxIterations); 45 | 46 | /// Train gaussian mixture model with the EM algorithm. Given a set of observations and an *Initialized* GMM, this function optimizes the location 47 | /// of the gaussians via iterative expectations and maximizations. 48 | /// The set of observations. 49 | /// [in,out] The computed GMM. This GMM should be already initialized by some other method (e.g., k-means), or at random.
50 | /// true if EM converged before reaching the max number of iterations 51 | bool Process(const std::vector& observations, GMM3D& gmm); 52 | 53 | /// Sets maximum number of iterations of EM. 54 | /// The maximum number of iterations. 55 | void setMaxIterations(int maxIters) { m_maxIterations = maxIters; } 56 | 57 | /// Sets EM tolerance. The stopping condition in EM is the relative change of the log likelihood between iterations, 58 | /// i.e.: |(newLogLikelihood - oldLogLikelihood) / oldLogLikelihood| <= tolerance finishes the process. 59 | /// The tolerance. 60 | void setTolerance(double tolerance) { m_tolerance = tolerance; } 61 | 62 | private: 63 | /// Updates the gaussian responsibilities in the linear domain, so that: responsibilities[k][n] = p_kn = exp(log(p(x_n|k)) + log(P(k)) - log(p(x_n))); non-finite values are stored as 0. 64 | /// (the responsibilities vector is stored internally). This corresponds to the E-step in EM. 65 | /// The input set of observations. 66 | /// The current GMM. 67 | void UpdateResponsibilities(const std::vector& observations, GMM3D& gmm); 68 | 69 | /// Updates the GMM weights P(k) = max(sum_n(p_kn)/N, c_SafeMinWeight) according to the (internally stored) responsibilities vector. This corresponds to the 1st part of the M-Step. 70 | /// [out] The gmm with updated weights. 71 | void UpdateWeights(GMM3D& gmm); 72 | 73 | /// Updates the GMM means according to the observations and responsibilities. This corresponds to the 2nd part of the M-step. Must be called after UpdateWeights, whose weights are used as the normalizer. 74 | /// The input set of observations. 75 | /// [out] The gmm with updated means. 76 | void UpdateMeans(const std::vector& observations, GMM3D& gmm); 77 | 78 | /// Updates the GMM covariance matrices according to the observations, responsibilities and means. This corresponds to the 3rd part of the M-step. Must be called after UpdateWeights and UpdateMeans. 79 | /// The input set of observations. 80 | /// [out] The gmm with updated covariance matrices. 
81 | void UpdateCovariances(const std::vector& observations, GMM3D& gmm); 82 | 83 | std::vector> m_tmpResponsibilities; 84 | int m_numTrainingPoints; // Number of observations used in training 85 | double m_tolerance; // Stopping condition in EM. If the relative change |(new - old) / old| of the log likelihood is at most the tolerance, finish process. 86 | int m_maxIterations; // Max number of iterations of EM. If we do not get under the tolerance in MaxIterations, we give up. 87 | }; 88 | } 89 | } 90 | 91 | #endif 92 | -------------------------------------------------------------------------------- /gmm/gaussian.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #ifndef __GAUSSIAN_H__ 29 | #define __GAUSSIAN_H__ 30 | 31 | #include "math_utils.h" 32 | 33 | namespace AC 34 | { 35 | static const double c_SafeDeterminantWithoutUnderflow = 1e-50; // minimum determinant without underflow in Gaussian evaluation 36 | static const double c_SafeMatrixRCOND = 1e-10; // Minimum value for us to consider that a covariance matrix is badly conditioned 37 | static const double c_SafeCovarianceFactor = 1e-10; // Factor added to the diagonal elements of an ill-conditioned covariance matrix to make it well-conditioned. Also the minimum variance in Reinitialize() and the identity scale used for near-singular covariances. 38 | 39 | template 40 | class GaussianDistribution 41 | { 42 | public: 43 | typedef std::shared_ptr> SP; 44 | 45 | /// Default constructor (initializes with mean 0 and unit spherical covariance). 46 | GaussianDistribution(); 47 | 48 | /// Constructor. 49 | /// The mean. 50 | /// The variance (spherical covariance matrix with 'variance' scaling). 51 | GaussianDistribution(const Vec& mean, double variance, double weight = 1); 52 | 53 | // Copy constructor 54 | GaussianDistribution(const GaussianDistribution& rhs); 55 | 56 | // swap contents 57 | void swap(GaussianDistribution& rhs); 58 | 59 | // Operator= 60 | GaussianDistribution& operator=(GaussianDistribution rhs); 61 | 62 | /// Evaluates Gaussian distribution at the given observation point. 63 | /// The observation. 64 | /// The value of the gaussian pdf(observation). 65 | double Evaluate(const Vec& observation); 66 | 67 | /// Evaluates log Gaussian distribution at the given observation point. 68 | /// The observation. 69 | /// The value of log (gaussian pdf(observation)).
70 | double EvaluateLog(const Vec& observation); 71 | 72 | /// Reinitialize Gaussian distribution with some mean and variance (spherical covariance matrix) 73 | /// The mean. 74 | /// The variance. We build a spherical covariance matrix cov with cov[i,i] = variance. 75 | /// (optional) the weight of this Gaussian distribution. 76 | void Reinitialize(const Vec& mean, double variance, double weight = 1); 77 | 78 | /// Rescale Gaussian (mean and covariance matrix) with the given scaling factors. The end result is that, after scaling Evaluate(obs) will transform into Evaluate(scaling * obs). 79 | /// The scaling factors. 80 | void Rescale(const Vec& scalingFactors); 81 | 82 | const Mat& Covariance() const { return m_covariance; } 83 | const Mat& InvCovariance() const { return m_invCovariance; } 84 | void setCovariance(const Mat& cov); 85 | 86 | const Vec& Mean() const { return m_mean; } 87 | Vec& Mean() { return m_mean; } 88 | void setMean(const Vec& mean) { m_mean = mean; } 89 | 90 | double Weight() const { return m_weight; } 91 | double LogWeight() const { return m_logWeight; } 92 | void setWeight(double w) { m_weight = w; m_logWeight = log(w); } 93 | 94 | /// Enable/disable underflow protection. If enabled, setCovariance() replaces a covariance whose determinant is below c_SafeDeterminantWithoutUnderflow by c_SafeCovarianceFactor * Identity, and adds c_SafeCovarianceFactor to the diagonal of ill-conditioned (low RCOND) covariances. 95 | /// WARNING: This factor may change the shape/size of your covariance matrix. Keep it in mind! 96 | void EnableUnderflowProtection() { m_underflowProtection = true; } 97 | void DisableUnderflowProtection() { m_underflowProtection = false; } 98 | bool UnderflowProtection() const { return m_underflowProtection; } 99 | 100 | int Dimensions() const { return Dims; } 101 | 102 | private: 103 | Mat m_covariance; // Covariance matrix 104 | Mat m_invCovariance; // Inverse Covariance matrix 105 | Vec m_mean; // Mean of the distribution 106 | double m_logNormFactor; // Normalization factor in log scale.
That is: log(1/(2*PI)^(M/2)) - log(det(Cov)^(1/2)) = -0.5*log((2*PI)^M * det(Cov)), with M = Dims 107 | double m_weight; // Weight of distribution 108 | double m_logWeight; // log(Weight) of distribution 109 | bool m_underflowProtection; // Use underflow protection in covariance matrix. 110 | /* Temporaries */ 111 | Vec m_tmpCenteredObservation; // Temporary vector to avoid lots of constructor calls 112 | Vec m_tmpProdOutput; // Temporary vector to avoid lots of constructor calls 113 | }; 114 | typedef GaussianDistribution GaussianDistribution3D; 115 | 116 | #include "gaussian.inl" 117 | } 118 | 119 | #endif 120 | -------------------------------------------------------------------------------- /gmm/gaussian.inl: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #define _USE_MATH_DEFINES 29 | #include 30 | 31 | //---------------------------------------------------------------------------- 32 | template 33 | AC::GaussianDistribution::GaussianDistribution() : m_underflowProtection(true) 34 | { 35 | Reinitialize(Vec::Zero() /* zero mean */, 1.0 /* Unit variance */, 1.0 /* Unit weight */); 36 | } 37 | 38 | //---------------------------------------------------------------------------- 39 | template 40 | AC::GaussianDistribution::GaussianDistribution(const Vec& mean, double variance, double weight) : m_underflowProtection(true) 41 | { 42 | Reinitialize(mean, variance, weight); 43 | } 44 | 45 | //---------------------------------------------------------------------------- 46 | template 47 | AC::GaussianDistribution::GaussianDistribution(const GaussianDistribution& rhs) : 48 | m_weight(rhs.m_weight), 49 | m_logWeight(rhs.m_logWeight), 50 | m_logNormFactor(rhs.m_logNormFactor), 51 | m_mean(rhs.m_mean), 52 | m_covariance(rhs.m_covariance), 53 | m_invCovariance(rhs.m_invCovariance), 54 | m_underflowProtection(rhs.m_underflowProtection) 55 | { 56 | } 57 | 58 | //---------------------------------------------------------------------------- 59 | template 60 | void AC::GaussianDistribution::swap(GaussianDistribution& rhs) 61 | { 62 | using std::swap; 63 | swap(m_weight, rhs.m_weight); 64 | swap(m_logWeight, rhs.m_logWeight); 65 | swap(m_logNormFactor, rhs.m_logNormFactor); 66 | swap(m_mean, rhs.m_mean); 67 | swap(m_covariance, rhs.m_covariance); 68 | swap(m_invCovariance, rhs.m_invCovariance); 69 | swap(m_underflowProtection, rhs.m_underflowProtection); 70 | } 71 | 72 | //---------------------------------------------------------------------------- 73 | 
template 74 | GaussianDistribution& AC::GaussianDistribution::operator=(GaussianDistribution rhs) 75 | { 76 | swap(rhs); 77 | return *this; 78 | } 79 | 80 | //---------------------------------------------------------------------------- 81 | template 82 | void AC::GaussianDistribution::setCovariance(const Mat& cov) 83 | { 84 | m_covariance = cov; 85 | 86 | double determinant; 87 | if (UnderflowProtection()) { 88 | // A very small determinant leads to an invalid logNormFactor. If that happens, reset the covariance matrix to a diagonal one. 89 | determinant = m_covariance.determinant(); 90 | if (determinant < c_SafeDeterminantWithoutUnderflow) { 91 | m_covariance.setIdentity(); 92 | m_covariance *= c_SafeCovarianceFactor; 93 | } 94 | 95 | // Is this covariance matrix degenerate? If RCOND is close to zero, it means that cov is ill-conditioned (close to degenerate). 96 | if (RCOND(m_covariance) < c_SafeMatrixRCOND) { 97 | // Make cov = cov + eye(Dims)*SomeSmallValue to make it better conditioned, even if it's inaccurate. 
98 | for (int i = 0; i < Dims; ++i) 99 | m_covariance(i, i) += c_SafeCovarianceFactor; 100 | } 101 | } 102 | 103 | determinant = m_covariance.determinant(); 104 | m_invCovariance = m_covariance.inverse(); 105 | m_logNormFactor = -0.5*log(pow(M_PI * 2, Dims) * determinant); 106 | } 107 | 108 | //---------------------------------------------------------------------------- 109 | template 110 | double AC::GaussianDistribution::EvaluateLog(const Vec& observation) 111 | { 112 | double exponentialTerm = ((observation - Mean()).transpose() * InvCovariance() * (observation - Mean())); 113 | return m_logNormFactor - 0.5 * exponentialTerm; 114 | } 115 | 116 | //---------------------------------------------------------------------------- 117 | template 118 | double AC::GaussianDistribution::Evaluate(const Vec& observation) 119 | { 120 | return exp(EvaluateLog(observation)); 121 | } 122 | 123 | //---------------------------------------------------------------------------- 124 | template 125 | void AC::GaussianDistribution::Reinitialize(const Vec& mean, double variance, double weight) 126 | { 127 | m_mean = mean; 128 | Mat cov; 129 | cov.setIdentity(); 130 | cov *= UnderflowProtection() ? std::max(variance, c_SafeCovarianceFactor) : variance; 131 | setCovariance(cov); 132 | setWeight(weight); 133 | } 134 | 135 | //---------------------------------------------------------------------------- 136 | template 137 | void AC::GaussianDistribution::Rescale(const Vec& scalingFactors) 138 | { 139 | Mat scalingMat; 140 | scalingMat.setZero(); 141 | for (int i = 0; i < scalingFactors.size(); ++i) { 142 | _ASSERT(!UnderflowProtection() || (UnderflowProtection() && scalingFactors[i] > c_SafeCovarianceFactor && "Invalid Scaling Factor (should be > 0)")); 143 | scalingMat(i, i) = scalingFactors[i] > c_SafeCovarianceFactor ? 
1 / scalingFactors[i] : 1; 144 | } 145 | 146 | m_mean = scalingMat * m_mean; 147 | 148 | // Rescale covariance matrix: newCov = scaling * oldCov * scaling 149 | m_covariance = scalingMat * (m_covariance * scalingMat); 150 | setCovariance(m_covariance); // Update invCovariance and the gaussian normalization factors 151 | } 152 | -------------------------------------------------------------------------------- /gmm/gmm.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | */ 28 | #include 29 | #include 30 | #include "gmm.h" 31 | #include "em.h" 32 | #include "math_utils.h" 33 | #include "kmeans.h" // the file name is lower-case; the exact case matters on case-sensitive file systems 34 | #include "gaussian.h" 35 | 36 | // Local helper functions 37 | namespace 38 | { 39 | //---------------------------------------------------------------------------- 40 | bool CompareWeights3D(const AC::GaussianDistribution3D::SP& g1, const AC::GaussianDistribution3D::SP& g2) 41 | { 42 | return g1->Weight() > g2->Weight(); 43 | } 44 | 45 | //---------------------------------------------------------------------------- 46 | void SortModes(std::vector& modes) 47 | { 48 | std::sort(modes.begin(), modes.end(), CompareWeights3D); 49 | } 50 | } 51 | 52 | //---------------------------------------------------------------------------- 53 | double AC::GMM::LogSumExp(const std::vector& logValues1, const std::vector& logValues2) 54 | { 55 | _ASSERT(logValues1.size() == logValues2.size() && L"Vectors must have the same size"); 56 | 57 | // Trivial case: the log of an empty sum is log(0) = -inf (zero probability), e.g. for a GMM whose modes were all removed 58 | if (logValues1.empty() || logValues2.empty()) 59 | return -std::numeric_limits<double>::infinity(); 60 | 61 | // Find maximum value in sum of vectors.
62 | double maxLogValue = -std::numeric_limits::max(); 63 | for (size_t i = 0; i < logValues1.size(); i++) { 64 | double logValue = logValues1[i] + logValues2[i]; 65 | if (logValue > maxLogValue) 66 | maxLogValue = logValue; 67 | } 68 | 69 | double expsum = 0; 70 | for (size_t i = 0; i < logValues1.size(); i++) 71 | expsum += exp(logValues1[i] + logValues2[i] - maxLogValue); 72 | 73 | return maxLogValue + log(expsum); 74 | } 75 | 76 | //---------------------------------------------------------------------------- 77 | AC::GMM::GMM3D::GMM3D(int numModes) 78 | : m_tmpLogLikelihoods(numModes) 79 | , m_tmpLogWeights(numModes) 80 | , m_globalWeight(1.0) 81 | { 82 | 83 | // Reserve memory for the modes (the temporary vectors are already sized in the initializer list) 84 | m_modes.reserve(numModes); 85 | // Create the set of gaussians in our GMM 86 | for (int k = 0; k < numModes; ++k) 87 | m_modes.emplace_back(new GaussianDistribution3D()); 88 | } 89 | 90 | //---------------------------------------------------------------------------- 91 | AC::GMM::GMM3D::GMM3D(const GMM3D& rhs) 92 | : m_tmpLogLikelihoods(rhs.m_tmpLogLikelihoods.begin(), rhs.m_tmpLogLikelihoods.end()) 93 | , m_tmpLogWeights(rhs.m_tmpLogWeights.begin(), rhs.m_tmpLogWeights.end()) 94 | , m_globalWeight(rhs.m_globalWeight) 95 | { 96 | m_modes.reserve(rhs.m_modes.size()); 97 | for (auto& mode : rhs.m_modes) 98 | m_modes.emplace_back(new GaussianDistribution3D(*mode)); 99 | } 100 | 101 | //---------------------------------------------------------------------------- 102 | double AC::GMM::GMM3D::LogLikelihood(const Vec3& observation) 103 | { 104 | for (int k = 0; k < (int)m_modes.size(); ++k) { 105 | m_tmpLogLikelihoods[k] = Modes(k)->EvaluateLog(observation); 106 | m_tmpLogWeights[k] = Modes(k)->LogWeight(); 107 | } 108 | 109 | // We have to use the LogExpSum trick to evaluate likelihoods to avoid underflows 110 | return LogSumExp(m_tmpLogLikelihoods, m_tmpLogWeights); 111 | } 112 | 113 |
//---------------------------------------------------------------------------- 114 | double AC::GMM::GMM3D::Likelihood(const Vec3& observation) 115 | { 116 | double likelihood = exp(LogLikelihood(observation)) * GlobalWeight(); 117 | return IsFinite(likelihood) ? likelihood : 0; 118 | } 119 | 120 | //---------------------------------------------------------------------------- 121 | double AC::GMM::GMM3D::LogLikelihood(const std::vector& observations) 122 | { 123 | double sum = 0; 124 | for (const auto& observation : observations) 125 | sum += LogLikelihood(observation); 126 | 127 | // If NaN or infinite, return minimum possible value for log (corresponding to 0 probability) 128 | return IsFinite(sum) ? sum : -std::numeric_limits::max(); 129 | } 130 | 131 | //---------------------------------------------------------------------------- 132 | double AC::GMM::GMM3D::LogResponsibility(const Vec3& observation, int idxMode) 133 | { 134 | // output: p(k|x_n) = p(x_n|k)p(k) / sum_j(p(x_n|j)p(j)) 135 | // We need to use logs to avoid underflow, so: log p(k|x_n) = log p(x_n|k) + log p(k) - log(sum_j(p(x_n|j)p(j))) 136 | return Modes(idxMode)->EvaluateLog(observation) + Modes(idxMode)->LogWeight() - LogLikelihood(observation); 137 | } 138 | 139 | //---------------------------------------------------------------------------- 140 | bool AC::GMM::GMM3D::Process(const std::vector& observations, int numKMeansRestarts, Vec3* scalingFactors, double EMTolerance, int maxIterations) 141 | { 142 | // Compute K-Means to initialize GMM 143 | AC::KMeans3D kmeans((int)Modes().size()); 144 | std::vector assignments(observations.size()); 145 | kmeans.Process(observations, numKMeansRestarts, assignments); 146 | 147 | // Initialize GMM from K-Means results 148 | for (int k = 0; k < kmeans.numCentroids(); ++k) { 149 | Modes(k)->Reinitialize(kmeans.Centroids(k), kmeans.AvgVariance(k)); 150 | Modes(k)->setWeight(kmeans.CentroidAssignmentRatio(k)); 151 | } 152 | 153 | // Use EM to optimize GMM 154 | AC::GMM::EM
EMTraining((int)observations.size(), (int)Modes().size(), EMTolerance, maxIterations); 155 | bool success = EMTraining.Process(observations, *this); 156 | 157 | // Sort modes according to weight 158 | SortModes(Modes()); 159 | 160 | // Scale gaussian means and covariances if necessary 161 | if (scalingFactors != nullptr) 162 | for (int k = 0; k < (int)Modes().size(); ++k) 163 | Modes(k)->Rescale(*scalingFactors); 164 | 165 | // Prune bad modes 166 | RemoveBadModes(c_SafeMinWeight); 167 | 168 | // Set global weight based on the number of observations (to compare against other GMMs). NOTE(review): this stores log(N), while Likelihood() multiplies by GlobalWeight() linearly (and N == 1 gives a weight of 0) -- confirm this is intended. 169 | SetGlobalWeight(log(observations.size())); 170 | return success; 171 | } 172 | 173 | //---------------------------------------------------------------------------- 174 | double AC::GMM::GMM3D::ClosestMode(const Vec3& observation, int& mode) const 175 | { 176 | double bestLogProbability = -std::numeric_limits::max(); 177 | mode = INVALID_MODE; 178 | for (int k = 0; k < (int)Modes().size(); ++k) 179 | { 180 | double prob = Modes(k)->EvaluateLog(observation); 181 | if (prob > bestLogProbability) { 182 | mode = k; 183 | bestLogProbability = prob; 184 | } 185 | } 186 | return bestLogProbability; 187 | } 188 | 189 | //---------------------------------------------------------------------------- 190 | size_t AC::GMM::GMM3D::RemoveBadModes(double tolerance) 191 | { 192 | // Erase any mode with weight under a certain tolerance 193 | Modes().erase(std::remove_if( 194 | Modes().begin(), Modes().end(), 195 | [&](const mode_type& m) { return m->Weight() <= tolerance; }), 196 | Modes().end()); 197 | 198 | // If we erased any mode, need to resize and reset the temporary vectors.
199 | m_tmpLogLikelihoods.resize(Modes().size(), 0.0); 200 | m_tmpLogWeights.resize(Modes().size(), 0.0); 201 | 202 | return Modes().size(); 203 | } 204 | 205 | //---------------------------------------------------------------------------- 206 | double AC::GMM::GMMLikelihoodRatio(GMM3D& gmm1, GMM3D& gmm2, const Vec3& observation) 207 | { 208 | double fgLikelihood = gmm1.Likelihood(observation); 209 | double bgLikelihood = gmm2.Likelihood(observation); 210 | const double totalLikelihood = fgLikelihood + bgLikelihood; 211 | // Likelihood() returns 0 for underflowed or non-finite values, so both terms can be 0; avoid 0/0 (NaN) and report the neutral 0.5 instead. 212 | return totalLikelihood > 0 ? fgLikelihood / totalLikelihood : 0.5; 213 | } 214 | 215 | -------------------------------------------------------------------------------- /gmm/gmm.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #ifndef __GMM_H__ 29 | #define __GMM_H__ 30 | 31 | #include 32 | #include 33 | #include "gaussian.h" 34 | 35 | namespace AC 36 | { 37 | namespace GMM { 38 | 39 | // Constants 40 | const int c_KMeansRestarts = 10; // Default number of restarts for KMeans initialization 41 | const int INVALID_MODE = -1; // If you call ClosestMode() and there are no modes, the mode returned is INVALID_MODE 42 | const int c_EMDefaultMaxIterations = 10; // Max number of iterations of EM. If we do not get under the tolerance in MaxIterations, we give up. 43 | const double c_EMDefaultTolerance = 1e-4; // Stopping condition in EM. If ratio of new vs old log likelihoods is lower than tolerance, finish process. 44 | 45 | class GMM3D { 46 | public: 47 | typedef std::shared_ptr SP; 48 | typedef GaussianDistribution3D::SP mode_type; 49 | 50 | /// Constructor. 51 | /// Alcollet, 7/17/2013. 52 | /// Number of modes (Gaussians) in GMM. 53 | GMM3D(int numModes); 54 | GMM3D(const GMM3D& rhs); // Deep copy: every Gaussian is cloned 55 | GMM3D& operator=(const GMM3D& rhs) { if (this != &rhs) { GMM3D copy(rhs); /* deep copy, like the copy constructor; the implicit operator= would make both GMMs share the same GaussianDistribution3D::SP objects */ m_modes.swap(copy.m_modes); m_tmpLogLikelihoods.swap(copy.m_tmpLogLikelihoods); m_tmpLogWeights.swap(copy.m_tmpLogWeights); m_globalWeight = copy.m_globalWeight; } return *this; } 56 | /// Compute a GMM from a given set of observations. The GMM is initialized using k-means, and then we use EM to train the full-covariance GMMs. 57 | /// The observations. 58 | /// [optional] Number of restarts in KMeans initialization. 59 | /// [optional] Scaling factor for each observation. Useful if the observations have been whitened (so that the output GMM will still be unwhitened). 60 | /// [optional] Stopping condition in EM. If ratio of new vs old log likelihoods is lower than tolerance, finish process. 61 | /// [optional] Max number of iterations of EM. If we do not get under the EM tolerance in MaxIterations, we give up. 62 | /// true if EM converged within EMMaxIterations, false otherwise (the GMM is still trained, rescaled and pruned in both cases).
63 | bool Process(const std::vector& observations, 64 | int numKMeansRestarts = c_KMeansRestarts, 65 | Vec3* scalingFactors = nullptr, 66 | double EMTolerance = c_EMDefaultTolerance, 67 | int EMMaxIterations = c_EMDefaultMaxIterations); 68 | 69 | /// Compute the log likelihood of the mixture model for an observation x_n, such that: log P(x_n) = log ( sum_k ( p(x_n|k)*p(k) )) 70 | /// The observation. 71 | /// The log likelihood of the mixture model for this observation 72 | double LogLikelihood(const Vec3& observation); 73 | 74 | /// Compute the likelihood of the mixture model for an observation x_n, such that: P(x_n) = sum_k ( p(x_n|k)*p(k) ), scaled by GlobalWeight(). Non-finite results are reported as 0. 75 | /// The observation. 76 | /// Likelihood of the mixture model for this observation 77 | double Likelihood(const Vec3& observation); 78 | 79 | /// Compute the log likelihood of the mixture model for a set of observations, as: log P(X) = sum_{x_n in X} log P(x_n) 80 | /// The observations. 81 | /// Sum of the per-observation log likelihoods, or -std::numeric_limits<double>::max() if that sum is NaN or infinite. 82 | double LogLikelihood(const std::vector& observations); 83 | 84 | /// Compute the log responsibility log(p(k|x_n)) = log(p(x_n|k)) + log(P(k)) - log(p(x_n)). 85 | /// The observation x_n 86 | /// The mode index k. 87 | /// The log responsibility log(p(k|x_n)) 88 | double LogResponsibility(const Vec3& observation, int idxMode); 89 | 90 | /// Compute closest mode for a given observation. 91 | /// The observation. 92 | /// [out] Index of the closest mode K to this observation. 93 | /// The log density log p(observation|K) of the closest mode K (mode weights are not taken into account), or -std::numeric_limits<double>::max() if the GMM has no modes. 94 | double ClosestMode(const Vec3& observation, int& mode) const; 95 | 96 | /// Remove modes whose weight is <= tolerance. The weights of the remaining modes are not renormalized. 97 | /// Modes whose weight is at most this value are removed.
98 | /// The number of valid modes after removal 99 | size_t RemoveBadModes(double tolerance); 100 | 101 | std::vector& Modes() { return m_modes; } 102 | const std::vector& Modes() const { return m_modes; } 103 | 104 | GaussianDistribution3D::SP& Modes(int k) { _ASSERT(k < m_modes.size() && k >= 0 && L"Invalid index"); return m_modes[k]; } 105 | const GaussianDistribution3D::SP& Modes(int k) const { _ASSERT(k < m_modes.size() && k >= 0 && L"Invalid index"); return m_modes[k]; } 106 | 107 | double GlobalWeight() const { return m_globalWeight; } 108 | void SetGlobalWeight(double w) { m_globalWeight = w; } 109 | 110 | private: 111 | std::vector m_modes; // k-Vector containing the multiple Gaussians 112 | std::vector m_tmpLogLikelihoods; // k-Vector (temporary) to store the log likelihoods for each Gaussian 113 | std::vector m_tmpLogWeights; // k-Vector (temporary) to store the LogWeights for each Gaussian 114 | double m_globalWeight; // Total weight for this GMM distribution (by default, = 1) 115 | }; 116 | 117 | /// Compute log( sum_n( x1_n * x2_n )) from vectors of logValues1 and logValues2, where logValuesK[n] = log(xK_n). 118 | /// We use the log-sum-exp trick to avoid underflow: with M = max_n( log x1_n + log x2_n ), 119 | /// log( sum_n( x1_n * x2_n )) = M + log( sum_n( exp( log x1_n + log x2_n - M ))) 120 | /// The first vector of log values. 121 | /// The second vector of log values. 122 | /// log( sum_n( x1_n * x2_n )) 123 | double LogSumExp(const std::vector& logValues1, const std::vector& logValues2); 124 | 125 | /// Compute the probability of 'observation' to be a sample of gmm1 or gmm2. 126 | /// [in] The first gmm. 127 | /// [in] The second gmm. 128 | /// [in] The observation.
129 | /// The probability of 'observation' to be a sample of gmm1, i.e., p(obs|gmm1) / (p(obs|gmm1)+p(obs|gmm2)) 130 | double GMMLikelihoodRatio(GMM3D& gmm1, GMM3D& gmm2, const Vec3& observation); 131 | } 132 | } 133 | 134 | #endif -------------------------------------------------------------------------------- /gmm/gmm.vcxproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB} 23 | Win32Proj 24 | gmm 25 | 8.1 26 | 27 | 28 | 29 | StaticLibrary 30 | true 31 | v140 32 | Unicode 33 | 34 | 35 | StaticLibrary 36 | false 37 | v140 38 | true 39 | Unicode 40 | 41 | 42 | StaticLibrary 43 | true 44 | v140 45 | Unicode 46 | 47 | 48 | StaticLibrary 49 | false 50 | v140 51 | true 52 | Unicode 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | Level4 78 | Disabled 79 | WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) 80 | $(SolutionDir)External 81 | true 82 | 83 | 84 | Windows 85 | 86 | 87 | 88 | 89 | 90 | 91 | Level4 92 | Disabled 93 | _DEBUG;_LIB;%(PreprocessorDefinitions) 94 | $(SolutionDir)External 95 | true 96 | 97 | 98 | Windows 99 | 100 | 101 | 102 | 103 | Level4 104 | 105 | 106 | MaxSpeed 107 | true 108 | true 109 | WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) 110 | $(SolutionDir)External 111 | true 112 | 113 | 114 | Windows 115 | true 116 | true 117 | 118 | 119 | 120 | 121 | Level4 122 | 123 | 124 | MaxSpeed 125 | true 126 | true 127 | NDEBUG;_LIB;%(PreprocessorDefinitions) 128 | $(SolutionDir)External 129 | true 130 | 131 | 132 | Windows 133 | true 134 | true 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | --------------------------------------------------------------------------------
/gmm/gmm.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | Header Files 20 | 21 | 22 | Header Files 23 | 24 | 25 | Header Files 26 | 27 | 28 | Header Files 29 | 30 | 31 | Header Files 32 | 33 | 34 | Header Files 35 | 36 | 37 | 38 | 39 | Header Files 40 | 41 | 42 | Header Files 43 | 44 | 45 | Header Files 46 | 47 | 48 | 49 | 50 | Source Files 51 | 52 | 53 | Source Files 54 | 55 | 56 | -------------------------------------------------------------------------------- /gmm/kmeans.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #ifndef __KMEANS_H__ 29 | #define __KMEANS_H__ 30 | 31 | #include 32 | #include 33 | 34 | namespace AC 35 | { 36 | template 37 | class KMeans { 38 | public: 39 | typedef std::shared_ptr> SP; 40 | 41 | /// Constructor. 42 | /// Number of centroids to use in k-means (parameter k). 43 | KMeans(int numMeans); 44 | 45 | /// Compute K-Means algorithm on a set of observations with N random restarts. The resulting centroids are available through the function Centroids(). Each restart is scored with ClosestCentroids() and the lowest-scoring restart is kept. 46 | /// N-vector of D-dimensional observations. 47 | /// Number of times to restart the algorithm (with random initialization). 48 | /// [out] N-vector of point assignments. Assignment[i] = C --> means that observation[i] is clustered with centroid C. 49 | /// Number of observations whose assignment changed when all observations were re-assigned to the best centroids (compared with the assignments left by the last restart); 0 is also returned, with nothing computed, when there are fewer observations than centroids. 50 | int Process(const std::vector& observations, int restarts, std::vector& assignments); 51 | 52 | /// Compute closest centroid for a given observation. 53 | /// The observation. 54 | /// [out] The closest centroid to this observation. 55 | /// The Euclidean distance from the closest centroid to the observation. 56 | double ClosestCentroid(const Vec& observation, int& assignment); 57 | 58 | /// Compute closest centroids for a set of observations. 
59 | /// The set of observations. 60 | /// [out] The computed assignments. 61 | /// [out] Number of assignment changes from the previous iteration of KMeans. 62 | /// Sum over all centroids of the variance of the observation-to-centroid distances (the score Process() minimizes across restarts).
63 | double ClosestCentroids(const std::vector& observation, std::vector& assignments, int& numAssignmentChanges); 64 | 65 | /// Updates the centroids given a set of observations and assignments; a centroid with no assigned observations keeps its previous position. 66 | /// The observations. 67 | /// The assignments. 68 | void UpdateCentroids(const std::vector& observations, const std::vector& assignments); 69 | 70 | // Get centroids 71 | std::vector& Centroids() { return m_centroids; } 72 | const std::vector& Centroids() const { return m_centroids; } 73 | 74 | // Get centroid(k) 75 | Vec& Centroids(int k) { return m_centroids[k]; } 76 | const Vec& Centroids(int k) const { return m_centroids[k]; } 77 | 78 | // Number of centroids (parameter k) 79 | int numCentroids() const { return m_numMeans; } 80 | 81 | // Maximum number of iterations per restart 82 | void setMaxIterations(int maxIterations) { m_maxIterations = maxIterations; } 83 | 84 | // Mean and variance of the distances from the observations assigned to centroid k (statistics of the most recent assignment pass) 85 | double AvgDistance(int k) { return m_avgDistancesPerCentroid[k].Mean(); } 86 | double AvgVariance(int k) { return m_avgDistancesPerCentroid[k].Variance(); } 87 | 88 | // Fraction of the training observations assigned to centroid k (0 before Process() has stored any training points) 89 | double CentroidAssignmentRatio(int k) { return m_numTrainingPoints > 0 ? (double)m_avgDistancesPerCentroid[k].NumSamples() / (double)m_numTrainingPoints : 0.0; } 90 | 91 | // Returns number of samples in centroid k 92 | size_t NumSamplesAssignedToCentroid(int k) { return m_avgDistancesPerCentroid[k].NumSamples(); } 93 | 94 | private: 95 | int m_maxIterations; // Maximum number of iterations in KMeans (should not be necessary, in theory, but just in case...)
96 | int m_numMeans; // Parameter K in k-means 97 | std::vector m_centroids; // K-vector of centroids 98 | std::vector> m_avgDistancesPerCentroid; // K-vector of online mean/variance statistics of the distances to each centroid 99 | std::vector> m_tmpCentroidMeans; // K-Vector with helper classes to compute centroids 100 | int m_numTrainingPoints = 0; // Number of observations used in training (0 until Process() stores them) 101 | }; 102 | 103 | typedef KMeans KMeans2D; 104 | typedef KMeans KMeans3D; 105 | 106 | /// Basic test for KMeans. Generate a set of noisy observations around a set of centroids, run KMeans, and compare the centroids and assignments. 107 | /// true if passed the test 108 | bool TestKMeans3D(); // NOTE(review): no definition of AC::TestKMeans3D is visible; gmm_example.cpp defines a separate global ::TestKMeans3D - confirm a definition exists before calling this one. 109 | } 110 | 111 | #include "kmeans.inl" // Must match the on-disk file name exactly on case-sensitive file systems 112 | 113 | #endif 114 | -------------------------------------------------------------------------------- /gmm/kmeans.inl: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission.
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #include "random_generator.h" 29 | 30 | //---------------------------------------------------------------------------- 31 | template 32 | AC::KMeans::KMeans(int numMeans) : m_maxIterations(100), m_numMeans(numMeans), m_centroids(numMeans), m_avgDistancesPerCentroid(numMeans), m_tmpCentroidMeans(numMeans), m_numTrainingPoints(0) 33 | { 34 | // Initializers follow the member declaration order; every per-centroid vector is already sized to numMeans by its initializer. 35 | } 36 | 37 | //---------------------------------------------------------------------------- 38 | template 39 | int AC::KMeans::Process(const std::vector& observations, int numRestarts, std::vector& assignments) 40 | { 41 | if (numCentroids() <= 0 || (int)observations.size() < numCentroids()) // Seeding picks k distinct observations, so at least k (and k >= 1) are required 42 | return 0; // Nothing is computed and 'assignments' is left untouched 43 | 44 | this->m_centroids.resize(m_numMeans); 45 | for (int i = 0; i < int(this->m_centroids.size()); i++) 46 | this->m_centroids[i] = observations[0]; 47 | 48 | m_numTrainingPoints = (int)observations.size(); 49 | RandomGenerator subsetGenerator; 50 | std::vector bestCentroids(numCentroids(), observations[0]); 51 | double bestDistance = std::numeric_limits::max(); 52 | subsetGenerator.setLimits(0, (int)observations.size() - 1); 53 | std::vector initialCentroids(numCentroids()); 54 | 55 | // Ensure the assignments and observations have the same size 56 | assignments.resize(observations.size()); 57 | 58 | for (int i = 0; i < numRestarts; ++i) 59 | { 60 | // Choose initial centroids: k observation indices drawn without repetition (the return value is not checked, so if the generator gives up two centroids may start at the same observation) 61 | subsetGenerator.NonRepeatingSubset(initialCentroids); 62 | for (int j = 0; j < numCentroids(); ++j) 63 | Centroids(j) =
observations[initialCentroids[j]]; 64 | 65 | int assignmentChanges = 1; 66 | double restartScore = 0; 67 | int numIterations = 0; 68 | // Lloyd iterations: stop when no assignment changes, or after m_maxIterations 69 | while (assignmentChanges && numIterations < m_maxIterations) 70 | { 71 | // E-Step: assign each observation to its nearest centroid and score the assignment 72 | restartScore = ClosestCentroids(observations, assignments, assignmentChanges); 73 | 74 | // M-Step: move each centroid to the mean of its assigned observations 75 | UpdateCentroids(observations, assignments); 76 | 77 | numIterations++; 78 | } 79 | // Keep this restart if its score (sum of per-centroid distance variances) is the lowest seen so far 80 | if (restartScore < bestDistance) 81 | { 82 | bestCentroids = Centroids(); 83 | bestDistance = restartScore; 84 | } 85 | } 86 | // Update KMeans centroids with the best centroids we found 87 | Centroids() = bestCentroids; 88 | // Recompute assignments (and the per-centroid distance statistics) for the best centroids 89 | int numAssignmentChanges = 0; 90 | ClosestCentroids(observations, assignments, numAssignmentChanges); 91 | return numAssignmentChanges; 92 | } 93 | 94 | //---------------------------------------------------------------------------- 95 | template 96 | double AC::KMeans::ClosestCentroid(const Vec& observation, int& assignment) 97 | { 98 | double bestDistance = std::numeric_limits::max(); 99 | // Linear scan for the centroid at the smallest Euclidean distance (the first one wins ties). A kd-tree would only pay off with many centroids.
100 | for (int k = 0; k < numCentroids(); ++k) 101 | { 102 | double distance = (observation - Centroids(k)).norm(); 103 | if (distance < bestDistance) { 104 | assignment = k; 105 | bestDistance = distance; 106 | } 107 | } 108 | return bestDistance; // NOTE: 'assignment' is left unchanged when no distance compares below the initial maximum (e.g. no centroids, or a NaN observation) 109 | } 110 | 111 | //---------------------------------------------------------------------------- 112 | template 113 | double AC::KMeans::ClosestCentroids(const std::vector& observations, std::vector& assignments, int& numAssignmentChanges) 114 | { 115 | _ASSERT(observations.size() == assignments.size() && L"Vectors must be the same size"); 116 | _ASSERT((int)m_avgDistancesPerCentroid.size() == numCentroids() && L"Invalid vector size"); 117 | 118 | // Clean up centroid statistics 119 | for (auto& avgDistance : m_avgDistancesPerCentroid) 120 | avgDistance.Reset(); 121 | 122 | numAssignmentChanges = 0; 123 | 124 | 125 | for (int i = 0; i < (int)observations.size(); ++i) 126 | { 127 | // Choose nearest centroid, remembering the previous assignment so changes can be counted 128 | const int oldAssignment = assignments[i]; 129 | double distance = ClosestCentroid(observations[i], assignments[i]); 130 | 131 | // Keep track of assignment changes 132 | if (assignments[i] != oldAssignment) 133 | numAssignmentChanges++; 134 | 135 | // Update centroid statistics (online mean and variance of the distances) 136 | m_avgDistancesPerCentroid[assignments[i]].Push(distance); 137 | } 138 | 139 | // Score = sum over centroids of the variance of their distances (the metric Process() uses to choose between restarts) 140 | double sumVariances = 0; 141 | for (int i = 0; i < numCentroids(); ++i) 142 | sumVariances += m_avgDistancesPerCentroid[i].Variance(); 143 | 144 | return sumVariances; 145 | } 146 | 147 | //---------------------------------------------------------------------------- 148 | template 149 | void AC::KMeans::UpdateCentroids(const std::vector& observations, const std::vector& assignments) 150 | { 151 | // Reset the per-centroid running means 152 | for (auto& centroid : m_tmpCentroidMeans) 153 | centroid.Reset(); 154 | 155 | // Online computation of mean:
http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 156 | for (int i = 0; i < (int)observations.size(); ++i) 157 | m_tmpCentroidMeans[assignments[i]].Push(observations[i]); 158 | 159 | // Copy the new means to permanent storage; a centroid with no assigned observations keeps its previous position 160 | for (int k = 0; k < numCentroids(); ++k) 161 | { 162 | if (m_tmpCentroidMeans[k].NumSamples() != 0) 163 | Centroids(k) = m_tmpCentroidMeans[k].Mean(); 164 | } 165 | } -------------------------------------------------------------------------------- /gmm/math_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission.
27 | */ 28 | #ifndef __MATH_UTILS_H__ 29 | #define __MATH_UTILS_H__ 30 | 31 | //---------------------------------------------------------------------------- 32 | // Use Eigen by default (you will need to add it to your path) 33 | #include 34 | #include 35 | #include // Necessary to use aligned Eigen::Vector2f inside an std::vector 36 | 37 | namespace AC 38 | { 39 | typedef Eigen::Vector2d Vec2; 40 | typedef Eigen::Vector3d Vec3; 41 | typedef Eigen::Matrix2d Mat2; 42 | typedef Eigen::Matrix3d Mat3; 43 | 44 | // General version 45 | template 46 | inline void SetZero(T& val) 47 | { 48 | val = 0; 49 | } 50 | 51 | // Specialized version for Eigen 52 | inline void SetZero(Vec3 vec) 53 | { 54 | vec.setZero(); 55 | } 56 | 57 | // Utility to compute the mean of a list of values incrementally (more accurate), as in: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 58 | template 59 | class OnlineMean { 60 | public: 61 | OnlineMean() : m_numSamples(0) 62 | { 63 | SetZero(m_currentMean); 64 | } 65 | 66 | /// Push new value through the online mean. 67 | /// Alcollet, 7/30/2013. 68 | /// The value to push. 69 | /// The updated mean. 70 | const T& Push(const T& value) 71 | { 72 | ++m_numSamples; 73 | T delta = value - m_currentMean; 74 | m_currentMean = m_currentMean + delta / float(m_numSamples); 75 | return m_currentMean; 76 | } 77 | 78 | /// Return the current mean value. 79 | /// The current mean value. 80 | const T& Mean() const 81 | { 82 | _ASSERT(m_numSamples != 0 && "Mean() is ill-defined since there are no elements pushed yet"); 83 | return m_currentMean; 84 | } 85 | 86 | /// Resets this object (to zero mean, zero elements). 87 | void Reset() 88 | { 89 | m_numSamples = 0; 90 | SetZero(m_currentMean); 91 | } 92 | 93 | /// Return the current number of elements. 94 | /// The total number of elements. 
95 | int NumSamples() 96 | { 97 | return static_cast(m_numSamples); 98 | } 99 | 100 | private: 101 | size_t m_numSamples; 102 | T m_currentMean; 103 | }; 104 | 105 | //---------------------------------------------------------------------------- 106 | /// Compute mean and variance in a single loop using the online approach described in: 107 | /// Donald E. Knuth (1998). The Art of Computer Programming, volume 2: Seminumerical Algorithms, 3rd edn., p. 232. Boston: Addison-Wesley. 108 | /// Also see: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 109 | template 110 | class OnlineMeanVariance 111 | { 112 | public: 113 | OnlineMeanVariance() 114 | : m_numSamples(0) 115 | , m_currentMean(0) 116 | , m_currentMeanSq(0) 117 | , m_currentVariance(0) 118 | {} 119 | 120 | // Add a new value 121 | void Push(T value) 122 | { 123 | ++m_numSamples; 124 | T diff = value - m_currentMean; 125 | m_currentMean += diff / static_cast(m_numSamples); 126 | m_currentMeanSq += diff * (value - m_currentMean); 127 | if (m_numSamples > 1) 128 | m_currentVariance = m_currentMeanSq / (static_cast(m_numSamples - 1)); 129 | } 130 | 131 | // Reset calculations 132 | void Reset() 133 | { 134 | m_numSamples = 0; 135 | SetZero(m_currentMean); 136 | SetZero(m_currentMeanSq); 137 | SetZero(m_currentVariance); 138 | } 139 | 140 | // Get current number of samples 141 | size_t NumSamples() { return m_numSamples; } 142 | 143 | // Get current mean 144 | T Mean() { return m_currentMean; } 145 | 146 | // Get current variance 147 | T Variance() { return m_currentVariance; } 148 | 149 | private: 150 | size_t m_numSamples; 151 | T m_currentMeanSq; 152 | T m_currentMean; 153 | T m_currentVariance; 154 | }; 155 | 156 | //---------------------------------------------------------------------------- 157 | static const double c_SafeMinVariance = 1e-10; // Minimum variance we tolerate without being ill-conditioned 158 | static const double c_SafeMinWeight = 1e-20; // Minimum weight we consider valid 159 
| 160 | //---------------------------------------------------------------------------- 161 | // Normalize each dimension independently in all observations so that they have zero mean and unit variance. 162 | template 163 | void WhitenObservations(std::vector& observations, Vec& whiteningFactors, double SafeMinVariance = c_SafeMinVariance) 164 | { 165 | // Compute variance for each channel and store their inverses in whiteningFactors 166 | OnlineMeanVariance onlineVariance; 167 | for (int b = 0; b < whiteningFactors.Size(); ++b) { 168 | for (auto& observation : observations) 169 | onlineVariance.Push((double)observation[b]); 170 | whiteningFactors[b] = onlineVariance.Variance() > SafeMinVariance ? 1 / onlineVariance.Variance() : 1; // If variance is too small, don't invert 171 | onlineVariance.Reset(); 172 | } 173 | 174 | // Divide each observation by its variance 175 | for (int b = 0; b < whiteningFactors.size(); ++b) { 176 | for (auto& observation : observations) 177 | observation[b] *= whiteningFactors[b]; 178 | } 179 | } 180 | 181 | //---------------------------------------------------------------------------- 182 | /// Compute Norm_1 on a matrix. This corresponds to the maximum sum of column values over all columns. 183 | /// Type of the typename matrix. 184 | /// Matrix for which to calculate the Norm_1 185 | /// Norm_1(M) 186 | template 187 | double Norm1(const Mat& m) 188 | { 189 | // This is the conceptual code that should be executed 190 | /* 191 | double maxSumCols = 0; 192 | for (int j = 0; j < m.cols(); ++j) { 193 | double sumCol = 0; 194 | // Sum of all absolute values for one column 195 | for (int i = 0; i < m.rows(); ++i) 196 | sumCol += (double)abs(m(i, j)); 197 | // Maximum sum of column values over all columns 198 | if (sumCol > maxSumCols) 199 | maxSumCols = sumCol; 200 | } 201 | return maxSumCols; 202 | */ 203 | // This is the equivalent code in Eigen. 
Alternatively, we could also use Eigen::lpNorm<1>(m); 204 | return m.colwise().sum().maxCoeff(); 205 | } 206 | 207 | //---------------------------------------------------------------------------- 208 | /// Compute the reciprocal condition number of matrix M. This value should be close to 1.0 for well conditioned matrices. 209 | /// Type of the typename matrix. Should have ".inverse()" function. 210 | /// The const Mat& to process. 211 | /// . 212 | template 213 | double RCOND(const Mat& m) 214 | { 215 | return 1 / (Norm1(m) * Norm1(m.inverse())); 216 | } 217 | 218 | //---------------------------------------------------------------------------- 219 | /// Query if 'arg' is finite (not infinite or NaN) 220 | template bool IsFinite(T arg) 221 | { 222 | return arg == arg && 223 | arg != std::numeric_limits::infinity() && 224 | arg != -std::numeric_limits::infinity(); 225 | } 226 | } 227 | #endif -------------------------------------------------------------------------------- /gmm/random_generator.h: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | #ifndef __RANDOM_GENERATOR_H__ 29 | #define __RANDOM_GENERATOR_H__ 30 | 31 | #include 32 | 33 | namespace AC { 34 | 35 | class RandomGenerator { 36 | public: 37 | RandomGenerator() : distribution(0, 1), m_minValue(0), m_maxValue(1) {} // Initializers follow declaration order: 'distribution' is built first, so it must not read m_minValue/m_maxValue 38 | RandomGenerator(int minValue, int maxValue) : distribution(minValue, maxValue), m_minValue(minValue), m_maxValue(maxValue) {} // Requires minValue <= maxValue 39 | 40 | void setLimits(int minValue, int maxValue) { m_minValue = minValue; m_maxValue = maxValue; distribution.param(decltype(distribution)::param_type(minValue, maxValue)); } // Requires minValue <= maxValue; uses the standard distribution's param_type rather than the non-standard std::uniform_int 41 | 42 | /// Draw an integer random number in range [minValue, maxValue]. 43 | /// The random number in range [minValue, maxValue]. 44 | int draw() { return distribution(generator); } 45 | 46 | /// Return a subset of elements in range [m_minValue, m_maxValue] which are not repeated (the subset size is taken from values.size()). 47 | /// Type of the typename vector (e.g., std::vector or std::array). 48 | /// [in,out] The vector of non-repeated values in range [m_minValue, m_maxValue] 49 | /// true if it succeeds; false if the range holds fewer distinct values than requested, or if it gives up after maxTries probes for one element. 50 | template 51 | bool NonRepeatingSubset(Vec& values, int maxTries = c_MaxIterations); 52 | 53 | /// Query whether sampleIDs[idx] differs from every earlier entry sampleIDs[0..idx-1]. 54 | /// Type of the typename vector (e.g., std::vector or std::array). 55 | /// [in,out] The vector of sampled IDs. 56 | /// The index. 57 | /// true if unique sample identifier <typename vec>, false if not.
58 | template 59 | bool IsUniqueSampleID(Vec& sampleIDs, int idx); 60 | 61 | private: 62 | std::default_random_engine generator; // Default-seeded, so every instance produces the same draw sequence 63 | std::uniform_int_distribution distribution; 64 | int m_minValue; 65 | int m_maxValue; 66 | static const int c_MaxIterations = 25; // Default maximum number of probes per element in NonRepeatingSubset 67 | }; 68 | 69 | } // namespace AC 70 | 71 | #include "random_generator.inl" // Members are defined there with explicit AC:: qualification, so it is included at global scope 72 | 73 | #endif -------------------------------------------------------------------------------- /gmm/random_generator.inl: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission.
27 | */ 28 | //----------------------------------------------------------------------------- 29 | template 30 | bool AC::RandomGenerator::NonRepeatingSubset(Vec& sampleIDs, int maxTries) 31 | { 32 | // Invalid case 33 | int numElements = (int)sampleIDs.size(); 34 | if (sampleIDs.empty() || (m_maxValue - m_minValue + 1) < numElements) 35 | return false; 36 | 37 | // Trivial case (we only want a non-repeating subset, we don't care about ordering) 38 | if ((m_maxValue - m_minValue + 1) == numElements) { 39 | for (int i = 0; i < numElements; ++i) 40 | sampleIDs[i] = m_minValue + i; 41 | return true; 42 | } 43 | 44 | // Regular case (you should set your minimum and maximum value in setLimits) 45 | sampleIDs[0] = draw(); 46 | 47 | for (int i = 1; i < numElements; ++i) { 48 | int numIterations = 0; 49 | sampleIDs[i] = draw(); 50 | while (!IsUniqueSampleID(sampleIDs, i) && numIterations < c_MaxIterations) 51 | { 52 | sampleIDs[i] = (sampleIDs[i] + 1) % numElements; // Keep changing the random samples until the sequence is unique 53 | numIterations++; 54 | } 55 | if (numIterations == maxTries) // Give up if we try too many times 56 | return false; 57 | } 58 | return true; 59 | } 60 | 61 | //----------------------------------------------------------------------------- 62 | template 63 | bool AC::RandomGenerator::IsUniqueSampleID(Vec& sampleIDs, int idx) 64 | { 65 | _ASSERT(sampleIDs.size() > idx && L"Invalid input"); 66 | bool unique = true; 67 | // Query if idx is unique inside sampleIDs. For large vector sizes this will be inefficient and we should use a std::map, 68 | // but our typical use case is for 3-vectors (in which this is faster). 
69 | for (int j = 0; j < idx; ++j) { 70 | if (sampleIDs[idx] == sampleIDs[j]) { 71 | unique = false; 72 | break; 73 | } 74 | } 75 | return unique; 76 | } -------------------------------------------------------------------------------- /gmm_example/gmm_example.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Alvaro Collet 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | Neither name of this software nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | */ 28 | 29 | // gmm_example.cpp : Defines the entry point for the console application. 
30 | // 31 | 32 | #include "stdafx.h" 33 | #include "gmm/gmm.h" // Remember, you need to put the $(SolutionDir) in "Additional Include Directories" 34 | #include "gmm/kmeans.h" 35 | #include "gmm/math_utils.h" 36 | #include // Remember, you need to put the path to Eigen in "Additional Include Directories" 37 | #include 38 | 39 | /////////////////////////////////////////////////////////////////////////////// 40 | // Basic test to show how to use KMeans 41 | bool TestKMeans3D() 42 | { 43 | using namespace AC; 44 | 45 | // Create some centroids 46 | int numCentroids = 5; 47 | std::vector centroids(numCentroids); 48 | float scale = 10; 49 | centroids[0] = Vec3(-scale, 0, 0); 50 | centroids[1] = Vec3(0, -scale, 0); 51 | centroids[2] = Vec3(0, 0, -scale); 52 | centroids[3] = Vec3(scale, 0, 0); 53 | centroids[4] = Vec3(0, scale, 0); 54 | 55 | // Create some observations 56 | int numObservationsPerCentroid = 100; 57 | int numObservations = numObservationsPerCentroid * numCentroids; 58 | std::default_random_engine generator; 59 | std::uniform_real_distribution distribution(-scale / 4, scale / 4); 60 | std::vector observations(numObservations); 61 | std::vector assignmentsGT(numObservations); 62 | auto noise = std::bind(distribution, generator); 63 | Vec3 observation; 64 | for (int k = 0; k < numCentroids; ++k) { 65 | for (int i = 0; i < numObservationsPerCentroid; ++i) { 66 | observations[numObservationsPerCentroid*k + i] = Vec3(centroids[k][0] + float(noise()), centroids[k][1] + float(noise()), centroids[k][2] + float(noise())); assignmentsGT[numObservationsPerCentroid*k + i] = k; 67 | } 68 | } 69 | 70 | // Run kmeans 71 | KMeans kmeans(numCentroids); 72 | std::vector assignmentsKMeans(numObservations); 73 | kmeans.setMaxIterations(100); 74 | kmeans.Process(observations, 30 /* restart 30 times with different initializations */, assignmentsKMeans); 75 | 76 | // Compute average best distance between centroids and KMeans 77 | double avgBestDistance = 0; 78 | int assignment; 
79 | for (auto& centroidGT : centroids) 80 | avgBestDistance += kmeans.ClosestCentroid(centroidGT, assignment); 81 | avgBestDistance /= (int)centroids.size(); 82 | 83 | // Compute assignment differences between the ground truth and KMeans 84 | int badAssignments = 0; 85 | for (int i = 0; i < (int)assignmentsGT.size(); ++i) { 86 | for (int j = i; j < (int)assignmentsGT.size(); ++j) { 87 | if ((assignmentsGT[i] == assignmentsGT[j] && assignmentsKMeans[i] != assignmentsKMeans[j]) || 88 | (assignmentsGT[i] != assignmentsGT[j] && assignmentsKMeans[i] == assignmentsKMeans[j])) { 89 | badAssignments++; 90 | } 91 | } 92 | } 93 | // Assignments are pairwise, so there is a total of N*(N-1)/2 possible bad assignments; 94 | int maxAssignments = (int)assignmentsGT.size() * ((int)assignmentsGT.size() - 1) / 2; 95 | // If centroids are within 10% distance of GT, and there are less than 10% assignment errors, declare success. 96 | if (avgBestDistance < scale && badAssignments < (int)maxAssignments / 10) 97 | return true; 98 | else 99 | return false; 100 | } 101 | 102 | /////////////////////////////////////////////////////////////////////////////// 103 | // Basic test for GMM. Create a noisy set of observations from a (known) multivariate gaussian distribution, fit GMM to it, and compare the differences in labeling. 
104 | bool TestGMM3D() 105 | { 106 | using namespace AC; 107 | 108 | // Create some centroids 109 | int numModes = 5; 110 | std::vector centroids(numModes); 111 | double scale = 10; 112 | centroids[0] = Vec3(-scale, 0, 0); 113 | centroids[1] = Vec3(0, -scale, 0); 114 | centroids[2] = Vec3(0, 0, -scale); 115 | centroids[3] = Vec3(scale, 0, 0); 116 | centroids[4] = Vec3(0, scale, 0); 117 | 118 | // Per-mode linear transforms (diagonal scaling times a rotation). Despite the name, they are applied to the samples in the loop below rather than used as covariance matrices. 119 | std::vector covariances(numModes); 120 | covariances[0] = Vec3(1.0, 1.0, 1.0).asDiagonal() * Eigen::AngleAxisd(0.0, Vec3::UnitX()); 121 | covariances[1] = Vec3(2.0, 1.0, 0.5).asDiagonal() * Eigen::AngleAxisd(0.2, Vec3::UnitY()); 122 | covariances[2] = Vec3(1.0, 2.0, 1.0).asDiagonal() * Eigen::AngleAxisd(0.4, Vec3::UnitZ()); 123 | covariances[3] = Vec3(1.0, 1.0, 2.0).asDiagonal() * Eigen::AngleAxisd(0.6, Vec3::UnitX()); 124 | covariances[4] = Vec3(0.75, 1.0, 0.75).asDiagonal() * Eigen::AngleAxisd(0.8, Vec3::UnitY()); 125 | 126 | // Create some observations: transform (centroid + zero-mean uniform noise), so the expected center of mode k is covariances[k] * centroids[k] 127 | int numObservations = 100; // per mode 128 | std::default_random_engine generator; 129 | std::uniform_real_distribution distribution(-scale / 4, scale / 4); 130 | std::vector observations(numObservations * numModes); 131 | std::vector assignmentsGT(numObservations * numModes); 132 | auto noise = std::bind(distribution, generator); 133 | 134 | for (int k = 0; k < numModes; ++k) { 135 | for (int i = 0; i < numObservations; ++i) { 136 | observations[numObservations*k + i] = covariances[k] * Vec3(centroids[k][0] + noise(), centroids[k][1] + noise(), centroids[k][2] + noise()); 137 | assignmentsGT[numObservations*k + i] = k; 138 | } 139 | } 140 | 141 | // Run GMM 142 | GMM::GMM3D gmm(numModes); 143 | gmm.Process(observations); 144 | 145 | // Compute hard assignments and probabilities 146 | std::vector assignmentsGMM(numObservations * numModes); 147 | double avgProbability = 0; 148 | AC::OnlineMean avgLogProb; 149 | for (int i = 0; i <
(int)observations.size(); ++i) { 150 | gmm.ClosestMode(observations[i], assignmentsGMM[i] /* output value */); 151 | avgLogProb.Push(gmm.LogResponsibility(observations[i], assignmentsGMM[i])); 152 | } 153 | avgProbability = exp(avgLogProb.Mean()); // exp of the mean LogResponsibility, i.e. a geometric mean 154 | 155 | // Count observation pairs on which the ground truth and GMM disagree (same cluster in one, different clusters in the other) 156 | int badAssignments = 0; 157 | for (int i = 0; i < (int)assignmentsGT.size(); ++i) { 158 | for (int j = i + 1; j < (int)assignmentsGT.size(); ++j) { 159 | if ((assignmentsGT[i] == assignmentsGT[j] && assignmentsGMM[i] != assignmentsGMM[j]) || 160 | (assignmentsGT[i] != assignmentsGT[j] && assignmentsGMM[i] == assignmentsGMM[j])) { 161 | badAssignments++; 162 | } 163 | } 164 | } 165 | // Assignments are pairwise, so there is a total of N*(N-1)/2 possible bad assignments (pairs i < j); 166 | int maxAssignments = (int)assignmentsGT.size() * ((int)assignmentsGT.size() - 1) / 2; 167 | // If the geometric-mean responsibility is larger than 0.9, and fewer than 10% of the pairs disagree, declare success.
168 | if (avgProbability > 0.9 && badAssignments < maxAssignments / 10) 169 | return true; 170 | else 171 | return false; 172 | } 173 | 174 | /////////////////////////////////////////////////////////////////////////////// 175 | int main() 176 | { 177 | bool isKMeansOK = TestKMeans3D(); 178 | bool isGMMOK = TestGMM3D(); 179 | return (isKMeansOK && isGMMOK) ? 0 : 1; // Process exit code: 0 means success (returning the bool directly would report success as 1) 180 | } 181 | 182 | -------------------------------------------------------------------------------- /gmm_example/gmm_example.vcxproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {2BC833B9-98A3-4117-90EF-0FF768000C20} 23 | Win32Proj 24 | gmm_example 25 | 8.1 26 | 27 | 28 | 29 | Application 30 | true 31 | v140 32 | Unicode 33 | 34 | 35 | Application 36 | false 37 | v140 38 | true 39 | Unicode 40 | 41 | 42 | Application 43 | true 44 | v140 45 | Unicode 46 | 47 | 48 | Application 49 | false 50 | v140 51 | true 52 | Unicode 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | true 74 | 75 | 76 | true 77 | 78 | 79 | false 80 | 81 | 82 | false 83 | 84 | 85 | 86 | Use 87 | Level4 88 | Disabled 89 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 90 | true 91 | $(SolutionDir); $(SolutionDir)External 92 | true 93 | 94 | 95 | Console 96 | true 97 | 98 | 99 | 100 | 101 | Use 102 | Level4 103 | Disabled 104 | _DEBUG;_CONSOLE;%(PreprocessorDefinitions) 105 | true 106 | $(SolutionDir); $(SolutionDir)External 107 | true 108 | 109 | 110 | Console 111 | true 112 | 113 | 114 | 115 | 116 | Level4 117 | Use 118 | MaxSpeed 119 | true 120 | true 121 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 122 | true 123 | $(SolutionDir); $(SolutionDir)External 124 | true 125 | 126 | 127 | Console 128 | true 129 | true 130 | true 131 | 132 | 133 | 134 | 135 | Level4 136 | Use 137 | MaxSpeed 138 | true
139 | true 140 | NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 141 | true 142 | $(SolutionDir); $(SolutionDir)External 143 | true 144 | 145 | 146 | Console 147 | true 148 | true 149 | true 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | Create 160 | Create 161 | Create 162 | Create 163 | 164 | 165 | 166 | 167 | {c65777fc-3fd2-406f-84ef-c5d8d68f15bb} 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /gmm_example/gmm_example.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | Header Files 20 | 21 | 22 | Header Files 23 | 24 | 25 | 26 | 27 | Source Files 28 | 29 | 30 | Source Files 31 | 32 | 33 | -------------------------------------------------------------------------------- /gmm_example/stdafx.cpp: -------------------------------------------------------------------------------- 1 | // stdafx.cpp : source file that includes just the standard includes 2 | // gmm_example.pch will be the pre-compiled header 3 | // stdafx.obj will contain the pre-compiled type information 4 | 5 | #include "stdafx.h" 6 | 7 | // TODO: reference any additional headers you need in STDAFX.H 8 | // and not in this file 9 | -------------------------------------------------------------------------------- /gmm_example/stdafx.h: -------------------------------------------------------------------------------- 1 | // stdafx.h : include file for standard system include files, 2 | // or project specific include files that are used frequently, but 3 | // are changed infrequently 4 | // 5 | 6 | #pragma once 7 | 8 | #include 
"targetver.h" 9 | 10 | #include 11 | #include 12 | 13 | 14 | 15 | // TODO: reference additional headers your program requires here 16 | -------------------------------------------------------------------------------- /gmm_example/targetver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Including SDKDDKVer.h defines the highest available Windows platform. 4 | 5 | // If you wish to build your application for a previous Windows platform, include WinSDKVer.h and 6 | // set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. 7 | 8 | #include 9 | -------------------------------------------------------------------------------- /painless_gmm.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25123.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "gmm", "gmm\gmm.vcxproj", "{C65777FC-3FD2-406F-84EF-C5D8D68F15BB}" 7 | EndProject 8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "gmm_example", "gmm_example\gmm_example.vcxproj", "{2BC833B9-98A3-4117-90EF-0FF768000C20}" 9 | EndProject 10 | Global 11 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 12 | Debug|x64 = Debug|x64 13 | Debug|x86 = Debug|x86 14 | Release|x64 = Release|x64 15 | Release|x86 = Release|x86 16 | EndGlobalSection 17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 18 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Debug|x64.ActiveCfg = Debug|x64 19 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Debug|x64.Build.0 = Debug|x64 20 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Debug|x86.ActiveCfg = Debug|Win32 21 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Debug|x86.Build.0 = Debug|Win32 22 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Release|x64.ActiveCfg = Release|x64 23 | 
{C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Release|x64.Build.0 = Release|x64 24 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Release|x86.ActiveCfg = Release|Win32 25 | {C65777FC-3FD2-406F-84EF-C5D8D68F15BB}.Release|x86.Build.0 = Release|Win32 26 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Debug|x64.ActiveCfg = Debug|x64 27 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Debug|x64.Build.0 = Debug|x64 28 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Debug|x86.ActiveCfg = Debug|Win32 29 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Debug|x86.Build.0 = Debug|Win32 30 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Release|x64.ActiveCfg = Release|x64 31 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Release|x64.Build.0 = Release|x64 32 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Release|x86.ActiveCfg = Release|Win32 33 | {2BC833B9-98A3-4117-90EF-0FF768000C20}.Release|x86.Build.0 = Release|Win32 34 | EndGlobalSection 35 | GlobalSection(SolutionProperties) = preSolution 36 | HideSolutionNode = FALSE 37 | EndGlobalSection 38 | EndGlobal 39 | --------------------------------------------------------------------------------