├── bkmeans ├── __init__.py └── bkmeans.py ├── LICENSE ├── .gitignore ├── README.md └── setup.py /bkmeans/__init__.py: -------------------------------------------------------------------------------- 1 | from .bkmeans import BKMeans 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Bernd Fritzke 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Breathing *K*-Means Algorithm 2 | 3 | An approximation algorithm for the *k*-means problem that (on average) is **better** (higher solution quality) and **faster** (lower CPU time usage) than ***k*-means++** (in the widely-used implementation from `scikit-learn` and with `n_init=10`, the longtime default value for repetitions). 4 | 5 | **Remark**: Version 1.4 of `scikit-learn` (released 1/2024) introduced a new default value `n_init=1` for ***k*-means++**. As a consequence, the advantage of **breathing *k*-means** in solution quality has become even larger but ***k*-means++** is faster now (by sacrificing solution quality). 
6 | 7 | **Technical Report:** 8 | https://arxiv.org/abs/2006.15666 9 | 10 | **Example Code:** [Google Colab notebook](https://colab.research.google.com/drive/1tT3lPJQqSIzQ7NuIhRa8DlSU11lO9jgo?usp=sharing) 11 | 12 | ## API 13 | The included class **BKMeans** is subclassed from [scikit-learn's **KMeans** class](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html) 14 | and therefore has **the same API**. It can be used as a drop-in replacement for scikit-learn's **KMeans**. 15 | 16 | There is one new parameter that can be ignored (i.e., left at its default) for normal usage: 17 | 18 | * *m* (breathing depth), default: 5 19 | 20 | The parameter *m* can, however, also be used to generate faster (1 < *m* < 5) or better (*m* > 5) solutions. For details, see the technical report above. 21 | 22 | 23 | ## Release Notes 24 | ### Version 1.3 25 | * bug fix: BKMeans.predict(df) now matches BKMeans.labels_ 26 | * bug fix: BKMeans is now fully deterministic if the parameter random_state is set (i.e., not `None`) 27 | * compatibility with scikit-learn: the `fit` method now accepts the (unused) parameter *y* 28 | 29 | Remark: the fixed bugs did not affect the quality of the codebook results in previous versions 30 | 31 | ### Version 1.2 32 | * make use of the optional `sample_weight` parameter in the `fit` method 33 | * (contributed by Björn Wiescholek) 34 | 35 | ### Version 1.1 36 | * speed improvement by setting `n_init=1` by default 37 | * close centroids now defined by nearest neighbor criterion 38 | * parameter `theta` abolished 39 | 40 | ### Version 1.0 41 | * (initial release) 42 | * "close centroids" were based on distance and a parameter *theta* 43 | 44 | ## Example 1: running on a simple random data set 45 | Code: 46 | ```python 47 | import numpy as np 48 | from bkmeans import BKMeans 49 | 50 | # generate random data set 51 | X=np.random.rand(1000,2) 52 | 53 | # create BKMeans instance 54 | bkm = BKMeans(n_clusters=100) 55 | 56 | # run the algorithm 57 | 
bkm.fit(X) 58 | 59 | # print SSE (inertia in scikit-learn terms) 60 | print(bkm.inertia_) 61 | ``` 62 | Output: 63 | ``` 64 | 1.1775040547902602 65 | ``` 66 | 67 | ## Example 2: comparison with *k*-means++ (multiple runs) 68 | Code: 69 | ```python 70 | import numpy as np 71 | from sklearn.cluster import KMeans 72 | from bkmeans import BKMeans 73 | import time 74 | 75 | # random 2D data set 76 | X=np.random.rand(1000,2) 77 | 78 | # number of centroids 79 | k=100 80 | 81 | for i in range(5): 82 | # kmeans++ 83 | kmp = KMeans(n_clusters=k, n_init=10) 84 | kmp.fit(X) 85 | 86 | # breathing k-means 87 | bkm = BKMeans(n_clusters=k) 88 | bkm.fit(X) 89 | 90 | # relative SSE improvement of bkm over km++ 91 | imp = 1 - bkm.inertia_/kmp.inertia_ 92 | print(f"SSE improvement over k-means++: {imp:.2%}") 93 | ``` 94 | Output: 95 | 96 | ``` 97 | SSE improvement over k-means++: 3.38% 98 | SSE improvement over k-means++: 4.16% 99 | SSE improvement over k-means++: 6.14% 100 | SSE improvement over k-means++: 6.79% 101 | SSE improvement over k-means++: 4.76% 102 | ``` 103 | 104 | 105 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'bkmeans' 16 | DESCRIPTION = 'The breathing k-means algorithm' 17 | URL = 'https://github.com/gittar/bkmeans' 18 | EMAIL = 'fritzke@web.de' 19 | AUTHOR = 'Bernd Fritzke' 20 | REQUIRES_PYTHON = '>=3.7.0' 21 | VERSION = '1.2' 22 | 23 | # What packages are required for this module to be executed? 
24 | REQUIRED = [ 25 | 'numpy', 'scikit-learn>=0.23.2', 'scipy' 26 | ] 27 | 28 | # What packages are optional? 29 | EXTRAS = { 30 | # 'fancy feature': ['django'], 31 | } 32 | 33 | # The rest you shouldn't have to touch too much :) 34 | # ------------------------------------------------ 35 | # Except, perhaps the License and Trove Classifiers! 36 | # If you do change the License, remember to change the Trove Classifier for that! 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 | 40 | # Import the README and use it as the long-description. 41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 42 | try: 43 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 44 | long_description = '\n' + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Load the package's __version__.py module as a dictionary. 49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 53 | exec(f.read(), about) 54 | else: 55 | about['__version__'] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | """Support setup.py upload.""" 60 | 61 | description = 'Build and publish the package.' 
62 | user_options = [] 63 | 64 | @staticmethod 65 | def status(s): 66 | """Prints things in bold.""" 67 | print('\033[1m{0}\033[0m'.format(s)) 68 | 69 | def initialize_options(self): 70 | pass 71 | 72 | def finalize_options(self): 73 | pass 74 | 75 | def run(self): 76 | try: 77 | self.status('Removing previous builds…') 78 | rmtree(os.path.join(here, 'dist')) 79 | except OSError: 80 | pass 81 | 82 | self.status('Building Source and Wheel (universal) distribution…') 83 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 84 | 85 | self.status('Uploading the package to PyPI via Twine…') 86 | os.system('twine upload dist/*') 87 | 88 | self.status('Pushing git tags…') 89 | os.system('git tag v{0}'.format(about['__version__'])) 90 | os.system('git push --tags') 91 | 92 | sys.exit() 93 | 94 | 95 | # Where the magic happens: 96 | setup( 97 | name=NAME, 98 | version=about['__version__'], 99 | description=DESCRIPTION, 100 | long_description=long_description, 101 | long_description_content_type='text/markdown', 102 | author=AUTHOR, 103 | author_email=EMAIL, 104 | python_requires=REQUIRES_PYTHON, 105 | url=URL, 106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 107 | # If your package is a single module, use this instead of 'packages': 108 | # py_modules=['mypackage'], 109 | 110 | # entry_points={ 111 | # 'console_scripts': ['mycli=mymodule:cli'], 112 | # }, 113 | install_requires=REQUIRED, 114 | extras_require=EXTRAS, 115 | include_package_data=True, 116 | license='MIT', 117 | classifiers=[ 118 | # Trove classifiers 119 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 120 | 'License :: OSI Approved :: MIT License', 121 | 'Programming Language :: Python', 122 | 'Programming Language :: Python :: 3', 123 | 'Programming Language :: Python :: 3.7', 124 | 'Programming Language :: Python :: Implementation :: CPython', 125 | 'Topic :: Scientific/Engineering', 126 | 'Topic :: Scientific/Engineering :: 
Artificial Intelligence' 127 | ], 128 | # $ setup.py publish support. 129 | cmdclass={ 130 | 'upload': UploadCommand, 131 | }, 132 | ) 133 | -------------------------------------------------------------------------------- /bkmeans/bkmeans.py: -------------------------------------------------------------------------------- 1 | # 2 | # breathing k-means reference implementation 3 | # (C) 2024 Bernd Fritzke 4 | # ADDED sample_weight by Björn Wiescholek 5 | # 6 | # common parameters: 7 | # X: data set 8 | # C: centroids 9 | 10 | import numpy as np 11 | from sklearn.cluster import KMeans 12 | from scipy.spatial.distance import cdist 13 | import math 14 | __version__="V1.3" 15 | 16 | rng = np.random.default_rng() 17 | class BKMeans(KMeans): 18 | @staticmethod 19 | def get_version(): 20 | return __version__ 21 | 22 | def __init__(self, m=5, n_init=1, **kwargs): 23 | """ m: breathing depth 24 | n_init: number of times k-means++ is run initially 25 | kwargs: arguments for scikit-learns KMeans """ 26 | global rng 27 | if "random_state" in kwargs: 28 | # initialize rng with random_state 29 | rng = np.random.default_rng(kwargs["random_state"]) 30 | super().__init__(n_init=n_init, **kwargs) 31 | self.m = min(m,self.n_clusters) # ensure m <= k 32 | 33 | def get_error(self, X, C): 34 | """compute error per centroid""" 35 | # squared distances between data and centroids 36 | dist = cdist(X, C, metric="sqeuclidean") 37 | # indices to nearest centroid 38 | dist_min = np.argmin(dist,axis=1) 39 | # distances to nearest centroid 40 | d1 = dist[np.arange(len(X)), dist_min] 41 | # aggregate error for each centroid 42 | return np.array([np.sum(d1[dist_min==i]) for i in range(len(C))]) 43 | 44 | def get_utility(self, X, C): 45 | """compute utility per centroid""" 46 | # squared distances between data and centroids 47 | dist = cdist(X, C, metric="sqeuclidean") 48 | # indices to nearest and 2nd-nearest centroid 49 | dist_srt = dist.argpartition(kth=1)[:,:2] 50 | # squared distances to 
nearest centroid 51 | d1 = dist[np.arange(len(X)), dist_srt[:, 0]] 52 | # squared distances to 2nd-nearest centroid 53 | d2 = dist[np.arange(len(X)), dist_srt[:, 1]] 54 | # utility 55 | util = d2-d1 56 | # aggregate utility for each centroid 57 | return np.array([np.sum(util[dist_srt[:, 0]==i]) for i in range(len(C))]) 58 | 59 | def _lloyd(self,C,X,sample_weight): 60 | """perform Lloyd's algorithm""" 61 | self.init = C # set cluster centers 62 | self.n_clusters = len(C) # set k-value 63 | super().fit(X=X, sample_weight=sample_weight) # Lloyd's algorithm, sets self.inertia_ (a.k.a. phi) 64 | 65 | def fit(self, X, y=None, sample_weight=None): 66 | """ compute k-means clustering via breathing k-means (if m > 0) """ 67 | 68 | # run k-means++ (unless 'init' parameter specifies differently) 69 | super().fit(X=X, sample_weight=sample_weight) # requires self.n_clusters >= 1 70 | # handle trivial case k=1 71 | if self.n_clusters == 1: 72 | return self 73 | # assertion: self.n_clusters >= 2 74 | m = self.m 75 | # memorize best error and codebook so far 76 | E_best = self.inertia_ 77 | C_best = self.cluster_centers_ 78 | tmp = self.n_init, self.init # store for compatibility with sklearn 79 | # no multiple trials from here on 80 | self.n_init = 1 81 | while m > 0: 82 | # add m centroids ("breathe in") and run Lloyd's algorithm 83 | self._lloyd(self._breathe_in(X, self.cluster_centers_, m),X,sample_weight) 84 | # delete m centroids ("breathe out") and run Lloyd's algorithm 85 | self._lloyd(self._breathe_out(X, self.cluster_centers_, m),X,sample_weight) 86 | if self.inertia_ < E_best*(1-self.tol): 87 | # improvement! 
update memorized best error and codebook so far 88 | E_best = self.inertia_ 89 | C_best = self.cluster_centers_ 90 | else: 91 | m -= 1 # no improvement: reduce "breathing depth" 92 | self.n_init, self.init = tmp # restore for compatibility with sklearn 93 | self.inertia_ = E_best 94 | self.cluster_centers_ = C_best 95 | self.labels_ = self.predict(X) 96 | return self 97 | 98 | def _breathe_in(self, X, C, m): 99 | """ add m centroids near centroids with large error""" 100 | E = self.get_error(X, C) # per centroid 101 | # indices of max error centroids 102 | max_err = (-E).argpartition(kth=m-1)[:m] 103 | # multiplicative small constant for offset vectors 104 | eps = 0.01 105 | # root-mean-square error 106 | RMSE=math.sqrt(np.sum(E)/len(X)) 107 | Dplus = C[max_err]+eps*RMSE*(rng.random((m,C.shape[1]))-0.5) 108 | # return enlarged codebook 109 | return np.concatenate([C, Dplus]) 110 | 111 | def _breathe_out(self, X, C, m): 112 | """ remove m centroids while avoiding large error increase""" 113 | U = self.get_utility(X, C) # per centroid 114 | useless_sorted = U.argsort() 115 | # mutual distances among centroids (kxk-matrix) 116 | c_dist = cdist(C, C, metric="sqeuclidean") 117 | # index of nearest neighbor for each centroid 118 | nearest_neighbor=c_dist.argpartition(kth=1)[:,1] 119 | Dminus = set() # index set of centroids to remove 120 | Frozen = set() # index set of frozen centroids 121 | for useless in useless_sorted: 122 | # ensure that current centroid is not frozen 123 | if useless not in Frozen: 124 | # register current centroid for removal 125 | Dminus.add(useless) 126 | nn=nearest_neighbor[useless] 127 | if len(Frozen) + m < len(C): 128 | # freeze nearest neighbor centroid 129 | Frozen.add(nn) 130 | if len(Dminus) == m: 131 | # found m centroids to remove 132 | break 133 | # return reduced codebook 134 | return C[list(set(range(len(C)))-Dminus)] 135 | 136 | --------------------------------------------------------------------------------