├── .gitignore ├── .travis.yml ├── HOWTO_RELEASE.md ├── LICENSE ├── MANIFEST.in ├── MSTClustering.ipynb ├── Makefile ├── README.md ├── images └── SimpleClustering.png ├── mst_clustering ├── __init__.py ├── _mst_clustering.py └── tests │ ├── __init__.py │ └── test_mst_clustering.py ├── paper ├── mst_example.png ├── paper.bib ├── paper.json └── paper.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build/ 3 | dist/ 4 | MANIFEST 5 | 6 | .coverage 7 | 8 | #* 9 | *~ 10 | 11 | .ipynb_checkpoints -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | # sudo false implies containerized builds 4 | sudo: false 5 | 6 | python: 7 | - 2.7 8 | - 3.4 9 | - 3.5 10 | 11 | env: 12 | global: 13 | # Directory where tests are run from 14 | - TEST_DIR=/tmp/mst_clustering 15 | - CONDA_DEPS="scikit-learn nose" 16 | - PIP_DEPS="coveralls" 17 | matrix: 18 | - EXTRA_DEPS="" 19 | 20 | before_install: 21 | - export MINICONDA=$HOME/miniconda 22 | - export PATH="$MINICONDA/bin:$PATH" 23 | - hash -r 24 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 25 | - bash miniconda.sh -b -f -p $MINICONDA 26 | - conda config --set always_yes yes 27 | - conda update conda 28 | - conda info -a 29 | - conda install python=$TRAVIS_PYTHON_VERSION $CONDA_DEPS $EXTRA_DEPS 30 | - travis_retry pip install $PIP_DEPS 31 | 32 | install: 33 | - python setup.py install 34 | 35 | script: 36 | - mkdir -p $TEST_DIR 37 | - cd $TEST_DIR && nosetests -v --with-coverage --cover-package=mst_clustering mst_clustering 38 | 39 | after_success: 40 | - coveralls 41 | -------------------------------------------------------------------------------- /HOWTO_RELEASE.md: -------------------------------------------------------------------------------- 1 | # How to Release 2 | 
3 | Here's a quick step-by-step for cutting a new release of mst_clustering. 4 | 5 | ## Pre-release 6 | 7 | 1. update version in ``mst_clustering/__init__.py`` to, e.g. "0.1" 8 | 9 | 2. create a release tag; e.g. 10 | ``` 11 | $ git tag -a v0.1 -m 'version 0.1 release' 12 | ``` 13 | 14 | 3. push the commits and tag to github 15 | 16 | 4. confirm that CI tests pass on github 17 | 18 | 5. under "tags" on github, update the release notes 19 | 20 | 21 | ## Publishing the Release 22 | 23 | 1. push the new release to PyPI (requires jakevdp's permissions) 24 | ``` 25 | $ python setup.py sdist upload 26 | ``` 27 | 28 | ## Post-release 29 | 30 | 1. update version in ``mst_clustering/__init__.py`` to next version; e.g. '0.2.dev0' 31 | 32 | 2. push changes to github 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Jake Vanderplas 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.md 2 | include *.py 3 | recursive-include mst_clustering *.py 4 | include MSTClustering.ipynb 5 | include LICENSE 6 | include README 7 | include Makefile -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | nosetests mst_clustering 3 | 4 | doctest: 5 | nosetests --with-doctest mst_clustering 6 | 7 | test-coverage: 8 | nosetests --with-coverage --cover-package=mst_clustering 9 | 10 | test-coverage-html: 11 | nosetests --with-coverage --cover-html --cover-package=mst_clustering 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimum Spanning Tree Clustering 2 | 3 | [![build status](http://img.shields.io/travis/jakevdp/mst_clustering/master.svg?style=flat)](https://travis-ci.org/jakevdp/mst_clustering) 4 | [![version status](http://img.shields.io/pypi/v/mst_clustering.svg?style=flat)](https://pypi.python.org/pypi/mst_clustering) 5 | [![license](http://img.shields.io/badge/license-BSD-blue.svg?style=flat)](https://github.com/jakevdp/mst_clustering/blob/master/LICENSE) 6 | 
[![DOI](https://zenodo.org/badge/doi/10.5281/zenodo.50995.svg)](http://dx.doi.org/10.5281/zenodo.50995) 7 | [![JOSS](http://joss.theoj.org/papers/10.21105/joss.00012/status.svg)](http://joss.theoj.org/papers/10.21105/joss.00012) 8 | 9 | 10 | This package implements a simple scikit-learn style estimator for clustering 11 | with a minimum spanning tree. 12 | 13 | ## Motivation 14 | 15 | Automated clustering can be an important means of identifying structure in data, 16 | but many of the more popular clustering algorithms do not perform well in the 17 | presence of background noise. The clustering algorithm implemented here, based 18 | on a trimmed Euclidean Minimum Spanning Tree, can be useful in this case. 19 | 20 | ## Example 21 | 22 | The API of the ``mst_clustering`` code is designed for compatibility with 23 | the [scikit-learn](http://scikit-learn.org) project. 24 | 25 | ```python 26 | from mst_clustering import MSTClustering 27 | from sklearn.datasets import make_blobs 28 | import matplotlib.pyplot as plt 29 | 30 | # create some data with four clusters 31 | X, y = make_blobs(200, centers=4, random_state=42) 32 | 33 | # predict the labels with the MST algorithm 34 | model = MSTClustering(cutoff_scale=2) 35 | labels = model.fit_predict(X) 36 | 37 | # plot the results 38 | plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='rainbow'); 39 | ``` 40 | 41 | ![Simple Clustering Plot](https://raw.githubusercontent.com/jakevdp/mst_clustering/master/images/SimpleClustering.png) 42 | 43 | For a detailed explanation of the algorithm and a more interesting example of it in action, see the [MST Clustering Notebook](http://nbviewer.jupyter.org/github/jakevdp/mst_clustering/blob/master/MSTClustering.ipynb). 44 | 45 | ## Installation & Requirements 46 | 47 | The ``mst_clustering`` package itself is fairly lightweight. 
It is tested on 48 | Python 2.7 and 3.4-3.5, and depends on the following packages: 49 | 50 | - [numpy](http://numpy.org) 51 | - [scipy](http://scipy.org) 52 | - [scikit-learn](http://scikit-learn.org) 53 | 54 | Using the cross-platform [conda](http://conda.pydata.org/miniconda.html) 55 | package manager, these requirements can be installed as follows: 56 | 57 | ``` 58 | $ conda install numpy scipy scikit-learn 59 | ``` 60 | 61 | Finally, the current release of ``mst_clustering`` can be installed using ``pip``: 62 | ``` 63 | $ conda install pip # if using conda 64 | $ pip install mst_clustering 65 | ``` 66 | 67 | To install ``mst_clustering`` from source, first download the source repository and then run 68 | ``` 69 | $ python setup.py install 70 | ``` 71 | 72 | ## Contributing & Reporting Issues 73 | Bug reports, questions, suggestions, and contributions are welcome. 74 | For these, please make use of the 75 | [Issues](https://github.com/jakevdp/mst_clustering/issues) 76 | or [Pull Requests](https://github.com/jakevdp/mst_clustering/pulls) 77 | associated with this repository. 78 | 79 | ## Citing 80 | If you use this code in an academic publication, please consider 81 | citing this [JOSS Paper](http://joss.theoj.org/papers/10.21105/joss.00012).
82 | -------------------------------------------------------------------------------- /images/SimpleClustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakevdp/mst_clustering/6f1fa76bfd04bfd22119edb85d67ef07ef092364/images/SimpleClustering.png -------------------------------------------------------------------------------- /mst_clustering/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.dev0" 2 | 3 | from ._mst_clustering import MSTClustering 4 | -------------------------------------------------------------------------------- /mst_clustering/_mst_clustering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimum Spanning Tree Clustering 3 | """ 4 | from __future__ import division 5 | 6 | import numpy as np 7 | 8 | from scipy import sparse 9 | from scipy.sparse.csgraph import minimum_spanning_tree, connected_components 10 | from scipy.sparse.csgraph._validation import validate_graph 11 | from sklearn.utils import check_array 12 | 13 | from sklearn.base import BaseEstimator, ClusterMixin 14 | from sklearn.neighbors import kneighbors_graph 15 | from sklearn.metrics import pairwise_distances 16 | 17 | 18 | class MSTClustering(BaseEstimator, ClusterMixin): 19 | """Minimum Spanning Tree Clustering 20 | 21 | Parameters 22 | ---------- 23 | cutoff : float, int, optional 24 | either the number of edges to cut (if cutoff >= 1) or the fraction of 25 | edges to cut (if 0 <= cutoff < 1). See also the ``cutoff_scale`` 26 | parameter. 27 | cutoff_scale : float, optional 28 | maximum length of edges to keep. All edges larger than cutoff_scale 29 | will be removed (see also ``cutoff`` parameter). 30 | min_cluster_size : int (default: 1) 31 | minimum number of points per cluster. Points belonging to smaller 32 | clusters will be assigned to the background (label -1).
33 | approximate : bool, optional (default: True) 34 | If True, then compute the approximate minimum spanning tree using 35 | n_neighbors nearest neighbors. If False, then compute the full 36 | O[N^2] edges (see Notes, below). 37 | n_neighbors : int, optional (default: 20) 38 | maximum number of neighbors of each point used for approximate 39 | Euclidean minimum spanning tree (MST) algorithm. Referenced only 40 | if ``approximate`` is True. See Notes below. 41 | metric : string (default "euclidean") 42 | Distance metric to use in computing distances. If "precomputed", then 43 | input is a [n_samples, n_samples] matrix of pairwise distances (either 44 | sparse, or dense with NaN/inf indicating missing edges) 45 | metric_params : dict or None (optional) 46 | dictionary of parameters passed to the metric. See documentation of 47 | sklearn.neighbors.NearestNeighbors for details. 48 | 49 | Attributes 50 | ---------- 51 | full_tree_ : sparse array, shape (n_samples, n_samples) 52 | Full minimum spanning tree over the fit data 53 | cluster_graph_ : sparse array, shape (n_samples, n_samples) 54 | Non-connected graph over the final clusters 55 | labels_ : array, length n_samples 56 | Labels of each point; -1 marks points assigned to the background 57 | 58 | Notes 59 | ----- 60 | This routine uses an approximate Euclidean minimum spanning tree (MST) 61 | to perform hierarchical clustering. A true Euclidean minimum spanning 62 | tree naively costs O[N^3]. Graph traversal algorithms only help so much, 63 | because all N^2 edges must be used as candidates. In this approximate 64 | algorithm, we use k << N edges from each point, so that the cost is only 65 | O[Nk log(Nk)]. For k = N, the approximation is exact; in practice for 66 | well-behaved data sets, the result is exact for k << N.
67 | """ 68 | def __init__(self, cutoff=None, cutoff_scale=None, min_cluster_size=1, 69 | approximate=True, n_neighbors=20, 70 | metric='euclidean', metric_params=None): 71 | self.cutoff = cutoff 72 | self.cutoff_scale = cutoff_scale 73 | self.min_cluster_size = min_cluster_size 74 | self.approximate = approximate 75 | self.n_neighbors = n_neighbors 76 | self.metric = metric 77 | self.metric_params = metric_params 78 | 79 | def fit(self, X, y=None): 80 | """Fit the clustering model 81 | 82 | Parameters 83 | ---------- 84 | X : array_like 85 | the data to be clustered: shape = [n_samples, n_features], or a [n_samples, n_samples] distance graph if metric='precomputed' 86 | """ 87 | if self.cutoff is None and self.cutoff_scale is None: 88 | raise ValueError("Must specify either cutoff or cutoff_frac")  # NOTE(review): there is no ``cutoff_frac`` parameter (it is ``cutoff_scale``), but test_bad_arguments matches this exact message 89 | 90 | # Compute the distance-based graph G from the points in X 91 | if self.metric == 'precomputed': 92 | # Input is already a graph. Copy if sparse 93 | # so we can overwrite for efficiency below. 94 | self.X_fit_ = None 95 | G = validate_graph(X, directed=True, 96 | csr_output=True, dense_output=False, 97 | copy_if_sparse=True, null_value_in=np.inf) 98 | elif not self.approximate: 99 | X = check_array(X) 100 | self.X_fit_ = X 101 | kwds = self.metric_params or {} 102 | G = pairwise_distances(X, metric=self.metric, **kwds) 103 | G = validate_graph(G, directed=True, 104 | csr_output=True, dense_output=False, 105 | copy_if_sparse=True, null_value_in=np.inf) 106 | else: 107 | # generate a sparse graph using n_neighbors of each point 108 | X = check_array(X) 109 | self.X_fit_ = X 110 | n_neighbors = min(self.n_neighbors, X.shape[0] - 1) 111 | G = kneighbors_graph(X, n_neighbors=n_neighbors, 112 | mode='distance', 113 | metric=self.metric, 114 | metric_params=self.metric_params) 115 | 116 | # HACK to keep explicit zeros (minimum spanning tree removes them) 117 | zero_fillin = G.data[G.data > 0].min() * 1E-8  # NOTE(review): .min() raises ValueError if G has no positive entries (e.g. all points identical) 118 | G.data[G.data == 0] = zero_fillin 119 | 120 | # Compute the minimum spanning tree of this graph 121 | self.full_tree_ =
minimum_spanning_tree(G, overwrite=True) 122 | 123 | # undo the hack to bring back explicit zeros 124 | self.full_tree_[self.full_tree_ == zero_fillin] = 0  # NOTE(review): a real edge weight exactly equal to zero_fillin would also be reset to 0 125 | 126 | # Partition the data by the cutoff 127 | N = G.shape[0] - 1  # number of edges in a spanning tree over G.shape[0] nodes 128 | if self.cutoff is None: 129 | i_cut = N 130 | elif 0 <= self.cutoff < 1: 131 | i_cut = int((1 - self.cutoff) * N) 132 | elif self.cutoff >= 1: 133 | i_cut = int(N - self.cutoff) 134 | else: 135 | raise ValueError('self.cutoff must be positive, not {0}' 136 | ''.format(self.cutoff)) 137 | 138 | # create the mask; we zero-out values where the mask is True, keeping only the i_cut shortest edges 139 | N = len(self.full_tree_.data)  # N is reused: number of edges actually stored in the tree 140 | if i_cut < 0: 141 | mask = np.ones(N, dtype=bool) 142 | elif i_cut >= N: 143 | mask = np.zeros(N, dtype=bool) 144 | else: 145 | mask = np.ones(N, dtype=bool) 146 | part = np.argpartition(self.full_tree_.data, i_cut) 147 | mask[part[:i_cut]] = False 148 | 149 | # additionally cut values above the ``cutoff_scale`` 150 | if self.cutoff_scale is not None: 151 | mask |= (self.full_tree_.data > self.cutoff_scale) 152 | 153 | # Trim the tree 154 | cluster_graph = self.full_tree_.copy() 155 | 156 | # Eliminate zeros from cluster_graph for efficiency. 157 | # We want to do this: 158 | # cluster_graph.data[mask] = 0 159 | # cluster_graph.eliminate_zeros() 160 | # but there could be explicit zeros in our data! 161 | # So we call eliminate_zeros() with a stand-in data array, 162 | # then replace the data when we're finished.
163 | original_data = cluster_graph.data 164 | cluster_graph.data = np.arange(1, len(cluster_graph.data) + 1)  # 1-based ids, so no stand-in value is 0 165 | cluster_graph.data[mask] = 0 166 | cluster_graph.eliminate_zeros() 167 | cluster_graph.data = original_data[cluster_graph.data.astype(int) - 1] 168 | 169 | # find connected components 170 | n_components, labels = connected_components(cluster_graph, 171 | directed=False) 172 | 173 | # remove clusters with fewer than min_cluster_size 174 | counts = np.bincount(labels) 175 | to_remove = np.where(counts < self.min_cluster_size)[0] 176 | 177 | if len(to_remove) > 0: 178 | for i in to_remove: 179 | labels[labels == i] = -1 180 | _, labels = np.unique(labels, return_inverse=True) 181 | labels -= 1 # keep -1 labels the same: -1 sorts first in np.unique, and to_remove is non-empty here 182 | 183 | # update cluster_graph by eliminating non-clusters 184 | # operationally, this means zeroing-out rows & columns where 185 | # the label is negative. 186 | I = sparse.eye(len(labels)) 187 | I.data[0, labels < 0] = 0  # assumes I is DIA format (scipy.sparse.eye default), so I.data has a single row 188 | 189 | # we could just do this: 190 | # cluster_graph = I * cluster_graph * I 191 | # but we want to be able to eliminate the zeros, so we use 192 | # the same indexing trick as above 193 | original_data = cluster_graph.data 194 | cluster_graph.data = np.arange(1, len(cluster_graph.data) + 1) 195 | cluster_graph = I * cluster_graph * I 196 | cluster_graph.eliminate_zeros() 197 | cluster_graph.data = original_data[cluster_graph.data.astype(int) - 1] 198 | 199 | self.labels_ = labels 200 | self.cluster_graph_ = cluster_graph 201 | return self 202 | 203 | def get_graph_segments(self, full_graph=False): 204 | """Convenience routine to get graph segments 205 | 206 | This is useful for visualization of the graph underlying the algorithm. 207 | 208 | Parameters 209 | ---------- 210 | full_graph : bool (default: False) 211 | If True, return the full graph of connections. Otherwise return 212 | the truncated graph representing clusters.
213 | 214 | Returns 215 | ------- 216 | segments : tuple of ndarrays 217 | the coordinates representing the graph. The tuple is of length 218 | n_features, and each array is of size (2, n_edges), holding the start and end coordinates of each edge. 219 | For n_features=2, the graph can be visualized in matplotlib with, 220 | e.g. ``plt.plot(segments[0], segments[1], '-k')`` 221 | """ 222 | if not hasattr(self, 'X_fit_'): 223 | raise ValueError("Must call fit() before get_graph_segments()") 224 | if self.metric == 'precomputed': 225 | raise ValueError("Cannot use ``get_graph_segments`` " 226 | "with precomputed metric.") 227 | 228 | n_samples, n_features = self.X_fit_.shape 229 | 230 | if full_graph: 231 | G = sparse.coo_matrix(self.full_tree_) 232 | else: 233 | G = sparse.coo_matrix(self.cluster_graph_) 234 | 235 | return tuple(np.vstack(arrs) for arrs in zip(self.X_fit_[G.row].T, 236 | self.X_fit_[G.col].T)) 237 | -------------------------------------------------------------------------------- /mst_clustering/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakevdp/mst_clustering/6f1fa76bfd04bfd22119edb85d67ef07ef092364/mst_clustering/tests/__init__.py -------------------------------------------------------------------------------- /mst_clustering/tests/test_mst_clustering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import (assert_, assert_equal, assert_allclose, 3 | assert_raises_regex) 4 | 5 | from nose import SkipTest 6 | 7 | from sklearn.datasets import make_blobs 8 | from sklearn.neighbors import kneighbors_graph 9 | from sklearn.metrics import pairwise_distances 10 | 11 | from mst_clustering import MSTClustering 12 | 13 | 14 | def test_simple_blobs(): 15 | X, y = make_blobs(100, random_state=42) 16 | 17 | def _check_params(kwds): 18 | y_pred = MSTClustering(n_neighbors=100, **kwds).fit_predict(X) 19 |
assert_equal(len(np.unique(y_pred)), 3) 20 | assert_allclose([np.std(y[y == i]) for i in range(3)], 0)  # NOTE(review): y[y == i] is constant, so this always passes; y_pred[y == i] was presumably intended 21 | 22 | for kwds in [dict(cutoff=2), dict(cutoff=0.02), dict(cutoff_scale=2.5)]: 23 | yield _check_params, kwds 24 | 25 | 26 | def test_n_clusters(): 27 | N = 30 28 | rng = np.random.RandomState(42) 29 | X = rng.rand(N, 3) 30 | 31 | def _check_n(n): 32 | y_pred = MSTClustering(cutoff=n, approximate=False).fit_predict(X) 33 | assert_equal(len(np.unique(y_pred)), n + 1) 34 | 35 | for n in range(30): 36 | yield _check_n, n 37 | 38 | 39 | def test_n_clusters_approximate(): 40 | N = 30 41 | rng = np.random.RandomState(42) 42 | X = rng.rand(N, 3) 43 | 44 | def _check_n(n): 45 | y_pred = MSTClustering(cutoff=n, 46 | n_neighbors=2, 47 | approximate=True).fit_predict(X) 48 | assert_equal(len(np.unique(y_pred)), n + 1) 49 | 50 | # due to approximation, there are 3 clusters for n in (1, 2) 51 | for n in range(3, 30): 52 | yield _check_n, n 53 | 54 | 55 | def test_explicit_zeros(): 56 | N = 30 57 | rng = np.random.RandomState(42) 58 | X = rng.rand(N, 3) 59 | X[-5:] = X[:5] 60 | 61 | def _check_n(n): 62 | y_pred = MSTClustering(cutoff=n).fit_predict(X) 63 | assert_equal(len(np.unique(y_pred)), n + 1) 64 | 65 | for n in range(30): 66 | yield _check_n, n 67 | 68 | 69 | def test_precomputed_metric(): 70 | N = 30 71 | n_neighbors = 10 72 | rng = np.random.RandomState(42) 73 | X = rng.rand(N, 3) 74 | 75 | G_sparse = kneighbors_graph(X, n_neighbors=n_neighbors, mode='distance') 76 | G_dense = G_sparse.toarray() 77 | G_dense[G_dense == 0] = np.nan 78 | 79 | kwds = dict(cutoff=0.1) 80 | y1 = MSTClustering(n_neighbors=n_neighbors, **kwds).fit_predict(X) 81 | y2 = MSTClustering(metric='precomputed', **kwds).fit_predict(G_sparse) 82 | y3 = MSTClustering(metric='precomputed', **kwds).fit_predict(G_dense) 83 | 84 | assert_allclose(y1, y2) 85 | assert_allclose(y2, y3) 86 | 87 | 88 | def test_precomputed_metric_with_duplicates(): 89 | N = 30 90 | n_neighbors = N - 1 91 | rng =
np.random.RandomState(42) 92 | 93 | # make data with duplicate points 94 | X = rng.rand(N, 3) 95 | X[-5:] = X[:5] 96 | 97 | # compute sparse distances 98 | G_sparse = kneighbors_graph(X, n_neighbors=n_neighbors, 99 | mode='distance') 100 | 101 | # compute dense distances 102 | G_dense = pairwise_distances(X, X) 103 | 104 | kwds = dict(cutoff=0.1) 105 | y1 = MSTClustering(n_neighbors=n_neighbors, **kwds).fit_predict(X) 106 | y2 = MSTClustering(metric='precomputed', **kwds).fit_predict(G_sparse) 107 | y3 = MSTClustering(metric='precomputed', **kwds).fit_predict(G_dense) 108 | 109 | assert_allclose(y1, y2) 110 | assert_allclose(y2, y3) 111 | 112 | 113 | def test_min_cluster_size(): 114 | N = 30 115 | rng = np.random.RandomState(42) 116 | X = rng.rand(N, 3) 117 | 118 | def _check(n, min_cluster_size): 119 | y_pred = MSTClustering(cutoff=n, 120 | n_neighbors=2, 121 | min_cluster_size=min_cluster_size, 122 | approximate=True).fit_predict(X) 123 | labels, counts = np.unique(y_pred, return_counts=True) 124 | counts = counts[labels >= 0] 125 | if len(counts): 126 | assert_(counts.min() >= min_cluster_size) 127 | 128 | # due to approximation, there are 3 clusters for n in (1, 2) 129 | for n in range(3, 30, 5): 130 | for min_cluster_size in [1, 3, 5]: 131 | yield _check, n, min_cluster_size 132 | 133 | 134 | def test_precomputed(): 135 | X, y = make_blobs(100, random_state=42) 136 | D = pairwise_distances(X) 137 | 138 | mst1 = MSTClustering(cutoff=0.1) 139 | mst2 = MSTClustering(cutoff=0.1, metric='precomputed') 140 | 141 | assert_equal(mst1.fit_predict(X), 142 | mst2.fit_predict(D)) 143 | 144 | 145 | def test_bad_arguments(): 146 | X, y = make_blobs(100, random_state=42) 147 | 148 | mst = MSTClustering() 149 | assert_raises_regex(ValueError, 150 | "Must specify either cutoff or cutoff_frac", 151 | mst.fit, X, y) 152 | 153 | mst = MSTClustering(cutoff=-1) 154 | assert_raises_regex(ValueError, "cutoff must be positive", mst.fit, X) 155 | 156 | mst = MSTClustering() 157 | msg =
"Must call fit\(\) before get_graph_segments()"  # NOTE(review): backslash-paren in a non-raw string is an invalid escape sequence (warns on newer Pythons); a raw string would avoid it 158 | assert_raises_regex(ValueError, msg, mst.get_graph_segments) 159 | 160 | mst = MSTClustering(cutoff=0, metric='precomputed') 161 | mst.fit(pairwise_distances(X)) 162 | msg = "Cannot use ``get_graph_segments`` with precomputed metric." 163 | assert_raises_regex(ValueError, msg, mst.get_graph_segments) 164 | 165 | 166 | def test_graph_segments_shape(): 167 | def check_shape(ndim, cutoff, N=10): 168 | X = np.random.rand(N, ndim) 169 | mst = MSTClustering(cutoff=cutoff).fit(X) 170 | 171 | segments = mst.get_graph_segments() 172 | print(ndim, cutoff, segments[0].shape) 173 | assert len(segments) == ndim 174 | assert all(seg.shape == (2, N - 1 - cutoff) for seg in segments) 175 | 176 | segments = mst.get_graph_segments(full_graph=True) 177 | print(segments[0].shape) 178 | assert len(segments) == ndim 179 | assert all(seg.shape == (2, N - 1) for seg in segments) 180 | 181 | for N in [10, 15]: 182 | for ndim in [1, 2, 3]: 183 | for cutoff in [0, 1, 2]: 184 | yield check_shape, ndim, cutoff, N 185 | 186 | 187 | def check_graph_segments_vals():  # NOTE(review): lacks the ``test_`` prefix, so the test runner presumably never collects it 188 | X = np.arange(5)[:, None] ** 2 189 | mst = MSTClustering(cutoff=0).fit(X) 190 | segments = mst.get_graph_segments() 191 | assert len(segments) == 1 192 | assert_allclose(segments[0], 193 | [[0, 4, 4, 9], 194 | [1, 1, 9, 16]]) 195 | 196 | 197 | # check_estimator currently fails on this estimator; the leading double underscore presumably keeps the runner from collecting this test.
198 | def __test_estimator_checks(): 199 | try: 200 | from sklearn.utils.estimator_checks import check_estimator 201 | except ImportError: 202 | raise SkipTest("need scikit-learn 0.17+ for check_estimator()") 203 | 204 | check_estimator(MSTClustering) 205 | -------------------------------------------------------------------------------- /paper/mst_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakevdp/mst_clustering/6f1fa76bfd04bfd22119edb85d67ef07ef092364/paper/mst_example.png -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | @article{scikit-learn, 2 | title={Scikit-learn: Machine learning in Python}, 3 | author={Pedregosa, Fabian and Varoquaux, Ga{\"e}l and Gramfort, Alexandre and Michel, Vincent and Thirion, Bertrand and Grisel, Olivier and Blondel, Mathieu and Prettenhofer, Peter and Weiss, Ron and Dubourg, Vincent and others}, 4 | journal={The Journal of Machine Learning Research}, 5 | volume={12}, 6 | pages={2825--2830}, 7 | year={2011}, 8 | publisher={JMLR.
org} 9 | } 10 | 11 | @article{scikit-learn-api, 12 | title={API design for machine learning software: experiences from the scikit-learn project}, 13 | author={Buitinck, Lars and Louppe, Gilles and Blondel, Mathieu and Pedregosa, Fabian and Mueller, Andreas and Grisel, Olivier and Niculae, Vlad and Prettenhofer, Peter and Gramfort, Alexandre and Grobler, Jaques and others}, 14 | journal={arXiv preprint arXiv:1309.0238}, 15 | year={2013} 16 | } 17 | 18 | @book{ivezic2014, 19 | title={Statistics, Data Mining, and Machine Learning in Astronomy: A Practical Python Guide for the Analysis of Survey Data}, 20 | author={Ivezi{\'c}, {\v{Z}}eljko and Connolly, Andrew J and VanderPlas, Jacob T and Gray, Alexander}, 21 | year={2014}, 22 | publisher={Princeton University Press} 23 | } 24 | 25 | @Misc{scipy, 26 | author = {Eric Jones and Travis Oliphant and Pearu Peterson and others}, 27 | title = {{SciPy}: Open source scientific tools for {Python}}, 28 | year = {2001--}, 29 | url = "http://www.scipy.org/", 30 | note = {[Online; accessed 2016-05-04]} 31 | } 32 | -------------------------------------------------------------------------------- /paper/paper.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": "https://raw.githubusercontent.com/mbjones/codemeta/master/codemeta.jsonld", 3 | "@type": "Code", 4 | "author": [ 5 | { 6 | "@id": "0000-0002-9623-3401", 7 | "@type": "Person", 8 | "email": "jakevdp@uw.edu", 9 | "name": "Jake VanderPlas", 10 | "affiliation": "University of Washington eScience Institute" 11 | } 12 | ], 13 | "identifier": "https://zenodo.org/record/50995#.Vyp9DBUrJBw", 14 | "codeRepository": "http://github.com/jakevdp/mst_clustering", 15 | "datePublished": "2016-05-04", 16 | "dateModified": "2016-05-04", 17 | "dateCreated": "2016-05-04", 18 | "description": "Clustering via Euclidean Minimum Spanning Trees", 19 | "keywords": "machine learning", 20 | "license": "BSD", 21 | "title": "mst_clustering", 22 | 
"version": "v1.0" 23 | } 24 | -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'mst_clustering: Clustering via Euclidean Minimum Spanning Trees' 3 | tags: 4 | - machine learning 5 | - clustering 6 | authors: 7 | - name: Jake VanderPlas 8 | orcid: 0000-0002-9623-3401 9 | affiliation: University of Washington eScience Institute 10 | date: 04 May 2016 11 | bibliography: paper.bib 12 | --- 13 | 14 | # Summary 15 | 16 | This package contains a Python implementation of a clustering algorithm based 17 | on an efficiently-constructed approximate Euclidean minimum spanning tree 18 | (described in [@ivezic2014]). The method produces a Hierarchical clustering of 19 | input data, and is quite similar to single-linkage Agglomerative clustering. 20 | The advantage of this implementation is the ability to find significant clusters 21 | even in the presence of background noise, and is particularly useful for 22 | researchers hoping to detect structure in physical data. 23 | 24 | The code makes use of tools within SciPy [@scipy] and scikit-learn [@scikit-learn], 25 | and is designed for compatibility with the scikit-learn API [@scikit-learn-api]. 26 | 27 | -![Simple Clustering Example](mst_example.png) 28 | 29 | # References 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from distutils.core import setup 6 | 7 | 8 | def read(path, encoding='utf-8'): 9 | path = os.path.join(os.path.dirname(__file__), path) 10 | with io.open(path, encoding=encoding) as fp: 11 | return fp.read() 12 | 13 | 14 | def version(path): 15 | """Obtain the packge version from a python file e.g. pkg/__init__.py 16 | 17 | See . 
18 | """ 19 | version_file = read(path) 20 | version_match = re.search(r"""^__version__ = ['"]([^'"]*)['"]""", 21 | version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError("Unable to find version string.") 25 | 26 | 27 | DESCRIPTION = "Clustering with Minimum Spanning Trees" 28 | LONG_DESCRIPTION = """ 29 | mst_clustering: Clustering with Minimum Spanning Trees 30 | ====================================================== 31 | 32 | This package implements a scikit-learn style estimator for computing clustering 33 | with a minimum spanning tree. 34 | 35 | For more information, visit http://github.com/jakevdp/mst_clustering 36 | """ 37 | NAME = "mst_clustering" 38 | AUTHOR = "Jake VanderPlas" 39 | AUTHOR_EMAIL = "jakevdp@uw.edu" 40 | MAINTAINER = "Jake VanderPlas" 41 | MAINTAINER_EMAIL = "jakevdp@uw.edu" 42 | URL = 'http://github.com/jakevdp/mst_clustering' 43 | DOWNLOAD_URL = 'http://github.com/jakevdp/mst_clustering' 44 | LICENSE = 'new BSD' 45 | 46 | VERSION = version('mst_clustering/__init__.py') 47 | 48 | setup(name=NAME, 49 | version=VERSION, 50 | description=DESCRIPTION, 51 | long_description=LONG_DESCRIPTION, 52 | author=AUTHOR, 53 | author_email=AUTHOR_EMAIL, 54 | maintainer=MAINTAINER, 55 | maintainer_email=MAINTAINER_EMAIL, 56 | url=URL, 57 | download_url=DOWNLOAD_URL, 58 | license=LICENSE, 59 | packages=['mst_clustering', 60 | 'mst_clustering.tests', 61 | ], 62 | classifiers=[ 63 | 'Development Status :: 4 - Beta', 64 | 'Environment :: Console', 65 | 'Intended Audience :: Science/Research', 66 | 'License :: OSI Approved :: BSD License', 67 | 'Natural Language :: English', 68 | 'Programming Language :: Python :: 2.7', 69 | 'Programming Language :: Python :: 3.4', 70 | 'Programming Language :: Python :: 3.5'], 71 | ) 72 | --------------------------------------------------------------------------------