├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── active_semi_clustering ├── __init__.py ├── active │ ├── __init__.py │ └── pairwise_constraints │ │ ├── __init__.py │ │ ├── example_oracle.py │ │ ├── explore_consolidate.py │ │ ├── helpers.py │ │ ├── min_max.py │ │ ├── npu.py │ │ └── random.py ├── exceptions.py ├── farthest_first_traversal.py └── semi_supervised │ ├── __init__.py │ ├── labeled_data │ ├── __init__.py │ ├── constrainedkmeans.py │ ├── kmeans.py │ └── seededkmeans.py │ └── pairwise_constraints │ ├── __init__.py │ ├── constraints.py │ ├── copkmeans.py │ ├── mkmeans.py │ ├── mpckmeans.py │ ├── mpckmeansmf.py │ ├── pckmeans.py │ └── rcakmeans.py ├── examples └── Active-Semi-Supervised-Clustering.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | active-semi-supervised-clustering = {editable = true, path = "."} 8 | 9 | [dev-packages] 10 | jupyter = "*" 11 | scikit-learn = "*" 12 | setuptools = "*" 13 | wheel = "*" 14 | twine = "*" 15 | 16 | [requires] 17 | python_version = "3" 18 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "2973e7ae8a939a5aaed6fafcf3846461ac5263fdc33fb726af4a815ea3019e91" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "active-semi-supervised-clustering": { 20 | "editable": true, 21 | "path": "." 22 | }, 23 | "metric-learn": { 24 | "hashes": [ 25 | "sha256:3f3ccd61b6fd09ef780becab1f56a31c434d1d4ae9fc8b6386540ed91a0ba917", 26 | "sha256:697fa55bc11f97a36835cf70a7833b93bb5481a3468f503fb4da22bf0137b400" 27 | ], 28 | "markers": "python_version >= '2.7' and python_version != '3.0.*' and python_version != '3.3.*' and python_version != '3.2.*' and python_version != '3.1.*'", 29 | "version": "==0.4.0" 30 | }, 31 | "numpy": { 32 | "hashes": [ 33 | "sha256:1c362ad12dd09a43b348bb28dd2295dd9cdf77f41f0f45965e04ba97f525b864", 34 | "sha256:2156a06bd407918df4ac0122df6497a9c137432118f585e5b17d543e593d1587", 35 | "sha256:24e4149c38489b51fc774b1e1faa9103e82f73344d7a00ba66f6845ab4769f3f", 36 | "sha256:340ec1697d9bb3a9c464028af7a54245298502e91178bddb4c37626d36e197b7", 37 | "sha256:35db8d419345caa4eeaa65cd63f34a15208acd87530a30f0bc25fc84f55c8c80", 38 | "sha256:361370e9b7f5e44c41eee29f2bb5cb3b755abb4b038bce6d6cbe08db7ff9cb74", 39 | "sha256:36e8dcd1813ca92ce7e4299120cee6c03adad33d89b54862c1b1a100443ac399", 40 | "sha256:378378973546ecc1dfaf9e24c160d683dd04df871ecd2dcc86ce658ca20f92c0", 41 | "sha256:419e6faee16097124ee627ed31572c7e80a1070efa25260b78097cca240e219a", 42 | "sha256:4287104c24e6a09b9b418761a1e7b1bbde65105f110690ca46a23600a3c606b8", 43 | "sha256:549f3e9778b148a47f4fb4682955ed88057eb627c9fe5467f33507c536deda9d", 44 | "sha256:5e359e9c531075220785603e5966eef20ccae9b3b6b8a06fdfb66c084361ce92", 45 | "sha256:5ee7f3dbbdba0da75dec7e94bd7a2b10fe57a83e1b38e678200a6ad8e7b14fdc", 46 | "sha256:62d55e96ec7b117d3d5e618c15efcf769e70a6effaee5842857b64fb4883887a", 47 | "sha256:719b6789acb2bc86ea9b33a701d7c43dc2fc56d95107fd3c5b0a8230164d4dfb", 48 | "sha256:7a70f2b60d48828cba94a54a8776b61a9c2657a803d47f5785f8062e3a9c7c55", 49 | "sha256:7b9e37f194f8bcdca8e9e6af92e2cbad79e360542effc2dd6b98d63955d8d8a3", 50 | "sha256:83b8fc18261b70f45bece2d392537c93dc81eb6c539a16c9ac994c47fc79f09a", 51 | "sha256:9473ad28375710ab18378e72b59422399b27e957e9339c413bf00793b4b12df0", 52 | "sha256:95b085b253080e5d09f7826f5e27dce067bae813a132023a77b739614a29de6e", 53 | "sha256:98b86c62c08c2e5dc98a9c856d4a95329d11b1c6058cb9b5191d5ea6891acd09", 54 | "sha256:a3bd01d6d3ed3d7c06d7f9979ba5d68281f15383fafd53b81aa44b9191047cf8", 55 | "sha256:c81a6afc1d2531a9ada50b58f8c36197f8418ef3d0611d4c1d7af93fdcda764f", 56 | "sha256:ce75ed495a746e3e78cfa22a77096b3bff2eda995616cb7a542047f233091268", 57 | "sha256:dae8618c0bcbfcf6cf91350f8abcdd84158323711566a8c5892b5c7f832af76f", 58 | "sha256:df0b02c6705c5d1c25cc35c7b5d6b6f9b3b30833f9d178843397ae55ecc2eebb", 59 | "sha256:e3660744cda0d94b90141cdd0db9308b958a372cfeee8d7188fdf5ad9108ea82", 60 | "sha256:f2362d0ca3e16c37782c1054d7972b8ad2729169567e3f0f4e5dd3cdf85f188e" 61 | ], 62 | "markers": "python_version >= '2.7' and python_version != '3.0.*' and python_version != '3.3.*' and python_version != '3.2.*' and python_version != '3.1.*'", 63 | "version": "==1.15.1" 64 | }, 65 | "scikit-learn": { 66 | "hashes": [ 67 | "sha256:0a718b5ffbd5053fb3f9e1a2e20b7c4f256dd8035e246b907d3117d20bac0260", 68 | "sha256:1725540b754a9967778e9385e1ee2c8db50d5ab70ed835c9f5e36002ffabc169", 69 | "sha256:3e3ce307d7c5c5811658ba8686b24b571a8244eaafe707665ad601f400d5ce98", 70 | "sha256:42ad71502237c9fe300ecf157f5a394df717789a2dde541dd7034b539c70bdcc", 71 | "sha256:42cba716db197e0d1670e2fc13c4cc4a86d5c5358120ccfee6ec427b154e74ff", 72 | "sha256:47b4090b7686642e41176becb7c42ef3cc665d7ee0db5e7ea5d307ec9779327e", 73 | "sha256:51d99a08c8bf689cf60c9d8dca6e3d3e5f6d762def85ad735dcea11fb528a89b", 74 | "sha256:5f7577fbb2399a4712e96cf0e786638168940a876c33735a1b5d5a86ba4b1370", 75 | "sha256:66bfc2b6b15db1725d03ea657ec9184ff09dcbf1ecd834ef85f2edc2c9cbba97", 76 | "sha256:69a34d389d9ca4687ad00af4e11d53686771f484c37366f68617ef656bab16ab", 77 | "sha256:75297f3dd6685f01555f1bb75846995d45650af417280b69c81bf11b6987aed5", 78 | "sha256:9ebb38ab1d0ee143982aed561811903ac6c1abb512ae2b9019b3b65bde63ffb9", 79 | "sha256:a402c1484fe65df42d5dbc22a58e0695fe3afe2b0b229aee2a09c6d60ba8e5c2", 80 | "sha256:aad6b9aac1617bd7efa0450643888bbd3410679a94bc8680d9863825686ef369", 81 | "sha256:ad4db28d3dc16c01df75ed6efb72524537de3839a5d179fcf94094359fc72ec5", 82 | "sha256:b276739a5f863ccacb61999a3067d0895ee291c95502929b2ae56ea1f882e888", 83 | "sha256:b3dc88c4d2bcb26ffc5afe16d053ae28317d7d1de083651defcd5453a04f1563", 84 | "sha256:b3e4681253e95da5aa5c231889a32b084fd997962bf8beda6f796bf422f734b2", 85 | "sha256:c3d852d49d6c1710089d4513702099fa6f8e1aebfedf222319d80c47b0a195f8", 86 | "sha256:c6612e7e43988b8b5e1957150449493a55f9c059de641083df7a964f86f2d1e7", 87 | "sha256:c69e5c6051366a6ac9600d730276db939b1a205e42504ec0b8371f154b0058db", 88 | "sha256:ce121baa8e85ec27c3065281657dcd78adaab7dcb046c7fe96ad4e5a9dcb6610", 89 | "sha256:ed2a9a9bea6ec443b7effe5695c9c168b7bf9a67df6d880729760feda871b6a3", 90 | "sha256:efd842d70b87e3ef3429c3149840b9189d4441ca951ab0cec62c94a964e219d9", 91 | "sha256:f1428af5c381f6eef30ffbc7e047b7c713d4efa5d7bf5e57b62b3fc8d387044b", 92 | "sha256:f6c7bf8cd4de1640b760b47f4d28deb26dbbf9acbe0194cdff54a898e190d872", 93 | "sha256:f8329ac2160ad8bbbac6a507374685ceca3f24ca427fa9ee61a501280e1972d9", 94 | "sha256:fefba2a43b92f8393366093b60efbe984a72a2b41cce16b4002005e4104ef938" 95 | ], 96 | "version": "==0.19.2" 97 | }, 98 | "scipy": { 99 | "hashes": [ 100 | "sha256:0611ee97296265af4a21164a5323f8c1b4e8e15c582d3dfa7610825900136bb7", 101 | "sha256:08237eda23fd8e4e54838258b124f1cd141379a5f281b0a234ca99b38918c07a", 102 | "sha256:0e645dbfc03f279e1946cf07c9c754c2a1859cb4a41c5f70b25f6b3a586b6dbd", 103 | "sha256:0e9bb7efe5f051ea7212555b290e784b82f21ffd0f655405ac4f87e288b730b3", 104 | "sha256:108c16640849e5827e7d51023efb3bd79244098c3f21e4897a1007720cb7ce37", 105 | "sha256:340ef70f5b0f4e2b4b43c8c8061165911bc6b2ad16f8de85d9774545e2c47463", 106 | "sha256:3ad73dfc6f82e494195144bd3a129c7241e761179b7cb5c07b9a0ede99c686f3", 107 | "sha256:3b243c77a822cd034dad53058d7c2abf80062aa6f4a32e9799c95d6391558631", 108 | "sha256:404a00314e85eca9d46b80929571b938e97a143b4f2ddc2b2b3c91a4c4ead9c5", 109 | "sha256:423b3ff76957d29d1cce1bc0d62ebaf9a3fdfaf62344e3fdec14619bb7b5ad3a", 110 | "sha256:42d9149a2fff7affdd352d157fa5717033767857c11bd55aa4a519a44343dfef", 111 | "sha256:625f25a6b7d795e8830cb70439453c9f163e6870e710ec99eba5722775b318f3", 112 | "sha256:698c6409da58686f2df3d6f815491fd5b4c2de6817a45379517c92366eea208f", 113 | "sha256:729f8f8363d32cebcb946de278324ab43d28096f36593be6281ca1ee86ce6559", 114 | "sha256:8190770146a4c8ed5d330d5b5ad1c76251c63349d25c96b3094875b930c44692", 115 | "sha256:878352408424dffaa695ffedf2f9f92844e116686923ed9aa8626fc30d32cfd1", 116 | "sha256:8b984f0821577d889f3c7ca8445564175fb4ac7c7f9659b7c60bef95b2b70e76", 117 | "sha256:8f841bbc21d3dad2111a94c490fb0a591b8612ffea86b8e5571746ae76a3deac", 118 | "sha256:c22b27371b3866c92796e5d7907e914f0e58a36d3222c5d436ddd3f0e354227a", 119 | "sha256:d0cdd5658b49a722783b8b4f61a6f1f9c75042d0e29a30ccb6cacc9b25f6d9e2", 120 | "sha256:d40dc7f494b06dcee0d303e51a00451b2da6119acbeaccf8369f2d29e28917ac", 121 | "sha256:d8491d4784aceb1f100ddb8e31239c54e4afab8d607928a9f7ef2469ec35ae01", 122 | "sha256:dfc5080c38dde3f43d8fbb9c0539a7839683475226cf83e4b24363b227dfe552", 123 | "sha256:e24e22c8d98d3c704bb3410bce9b69e122a8de487ad3dbfe9985d154e5c03a40", 124 | "sha256:e7a01e53163818d56eabddcafdc2090e9daba178aad05516b20c6591c4811020", 125 | "sha256:ee677635393414930541a096fc8e61634304bb0153e4e02b75685b11eba14cae", 126 | "sha256:f0521af1b722265d824d6ad055acfe9bd3341765735c44b5a4d0069e189a0f40", 127 | "sha256:f25c281f12c0da726c6ed00535ca5d1622ec755c30a3f8eafef26cf43fede694" 128 | ], 129 | "markers": "python_version >= '2.7' and python_version != '3.0.*' and python_version != '3.3.*' and python_version != '3.2.*' and python_version != '3.1.*'", 130 | "version": "==1.1.0" 131 | }, 132 | "six": { 133 | "hashes": [ 134 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 135 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 136 | ], 137 | "version": "==1.11.0" 138 | } 139 | }, 140 | "develop": { 141 | "appnope": { 142 | "hashes": [ 143 | "sha256:5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0", 144 | "sha256:8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71" 145 | ], 146 | "markers": "sys_platform == 'darwin'", 147 | "version": "==0.1.0" 148 | }, 149 | "backcall": { 150 | "hashes": [ 151 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 152 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 153 | ], 154 | "version": "==0.1.0" 155 | }, 156 | "bleach": { 157 | "hashes": [ 158 | "sha256:0ee95f6167129859c5dce9b1ca291ebdb5d8cd7e382ca0e237dfd0dad63f63d8", 159 | "sha256:24754b9a7d530bf30ce7cbc805bc6cce785660b4a10ff3a43633728438c105ab" 160 | ], 161 | "version": "==2.1.4" 162 | }, 163 | "certifi": { 164 | "hashes": [ 165 | "sha256:376690d6f16d32f9d1fe8932551d80b23e9d393a8578c5633a2ed39a64861638", 166 | "sha256:456048c7e371c089d0a77a5212fb37a2c2dce1e24146e3b7e0261736aaeaa22a" 167 | ], 168 | "version": "==2018.8.24" 169 | }, 170 | "chardet": { 171 | "hashes": [ 172 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 173 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 174 | ], 175 | "version": "==3.0.4" 176 | }, 177 | "decorator": { 178 | "hashes": [ 179 | "sha256:2c51dff8ef3c447388fe5e4453d24a2bf128d3a4c32af3fabef1f01c6851ab82", 180 | "sha256:c39efa13fbdeb4506c476c9b3babf6a718da943dab7811c206005a4a956c080c" 181 | ], 182 | "version": "==4.3.0" 183 | }, 184 | "defusedxml": { 185 | "hashes": [ 186 | "sha256:24d7f2f94f7f3cb6061acb215685e5125fbcdc40a857eff9de22518820b0a4f4", 187 | "sha256:702a91ade2968a82beb0db1e0766a6a273f33d4616a6ce8cde475d8e09853b20" 188 | ], 189 | "version": "==0.5.0" 190 | }, 191 | "entrypoints": { 192 | "hashes": [ 193 | "sha256:10ad569bb245e7e2ba425285b9fa3e8178a0dc92fc53b1e1c553805e15a8825b", 194 | "sha256:d2d587dde06f99545fb13a383d2cd336a8ff1f359c5839ce3a64c917d10c029f" 195 | ], 196 | "markers": "python_version >= '2.7'", 197 | "version": "==0.2.3" 198 | }, 199 | "html5lib": { 200 | "hashes": [ 201 | "sha256:20b159aa3badc9d5ee8f5c647e5efd02ed2a66ab8d354930bd9ff139fc1dc0a3", 202 | "sha256:66cb0dcfdbbc4f9c3ba1a63fdb511ffdbd4f513b2b6d81b80cd26ce6b3fb3736" 203 | ], 204 | "version": "==1.0.1" 205 | }, 206 | "idna": { 207 | "hashes": [ 208 | "sha256:156a6814fb5ac1fc6850fb002e0852d56c0c8d2531923a51032d1b70760e186e", 209 | "sha256:684a38a6f903c1d71d6d5fac066b58d7768af4de2b832e426ec79c30daa94a16" 210 | ], 211 | "version": "==2.7" 212 | }, 213 | "ipykernel": { 214 | "hashes": [ 215 | "sha256:00d88b7e628e4e893359119b894451611214bce09776a3bf8248fe42cb48ada6", 216 | "sha256:a706b975376efef98b70e10cd167ab9506cf08a689d689a3c7daf344c15040f6", 217 | "sha256:c5a498c70f7765c34f3397cf943b069057f5bef4e0218e4cfbb733e9f38fa5fa" 218 | ], 219 | "version": "==4.9.0" 220 | }, 221 | "ipython": { 222 | "hashes": [ 223 | "sha256:007dcd929c14631f83daff35df0147ea51d1af420da303fd078343878bd5fb62", 224 | "sha256:b0f2ef9eada4a68ef63ee10b6dde4f35c840035c50fd24265f8052c98947d5a4" 225 | ], 226 | "markers": "python_version >= '3.3'", 227 | "version": "==6.5.0" 228 | }, 229 | "ipython-genutils": { 230 | "hashes": [ 231 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 232 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 233 | ], 234 | "version": "==0.2.0" 235 | }, 236 | "ipywidgets": { 237 | "hashes": [ 238 | "sha256:0f2b5cde9f272cb49d52f3f0889fdd1a7ae1e74f37b48dac35a83152780d2b7b", 239 | "sha256:a3e224f430163f767047ab9a042fc55adbcab0c24bbe6cf9f306c4f89fdf0ba3" 240 | ], 241 | "version": "==7.4.2" 242 | }, 243 | "jedi": { 244 | "hashes": [ 245 | "sha256:b409ed0f6913a701ed474a614a3bb46e6953639033e31f769ca7581da5bd1ec1", 246 | "sha256:c254b135fb39ad76e78d4d8f92765ebc9bf92cbc76f49e97ade1d5f5121e1f6f" 247 | ], 248 | "version": "==0.12.1" 249 | }, 250 | "jinja2": { 251 | "hashes": [ 252 | "sha256:74c935a1b8bb9a3947c50a54766a969d4846290e1e788ea44c1392163723c3bd", 253 | "sha256:f84be1bb0040caca4cea721fcbbbbd61f9be9464ca236387158b0feea01914a4" 254 | ], 255 | "version": "==2.10" 256 | }, 257 | "jsonschema": { 258 | "hashes": [ 259 | "sha256:000e68abd33c972a5248544925a0cae7d1125f9bf6c58280d37546b946769a08", 260 | "sha256:6ff5f3180870836cae40f06fa10419f557208175f13ad7bc26caa77beb1f6e02" 261 | ], 262 | "version": "==2.6.0" 263 | }, 264 | "jupyter": { 265 | "hashes": [ 266 | "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7", 267 | "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78", 268 | "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f" 269 | ], 270 | "index": "pypi", 271 | "version": "==1.0.0" 272 | }, 273 | "jupyter-client": { 274 | "hashes": [ 275 | "sha256:27befcf0446b01e29853014d6a902dd101ad7d7f94e2252b1adca17c3466b761", 276 | "sha256:59e6d791e22a8002ad0e80b78c6fd6deecab4f9e1b1aa1a22f4213de271b29ea" 277 | ], 278 | "version": "==5.2.3" 279 | }, 280 | "jupyter-console": { 281 | "hashes": [ 282 | "sha256:3f928b817fc82cda95e431eb4c2b5eb21be5c483c2b43f424761a966bb808094", 283 | "sha256:545dedd3aaaa355148093c5609f0229aeb121b4852995c2accfa64fe3e0e55cd" 284 | ], 285 | "version": "==5.2.0" 286 | }, 287 | "jupyter-core": { 288 | "hashes": [ 289 | "sha256:927d713ffa616ea11972534411544589976b2493fc7e09ad946e010aa7eb9970", 290 | "sha256:ba70754aa680300306c699790128f6fbd8c306ee5927976cbe48adacf240c0b7" 291 | ], 292 | "version": "==4.4.0" 293 | }, 294 | "markupsafe": { 295 | "hashes": [ 296 | "sha256:a6be69091dac236ea9c6bc7d012beab42010fa914c459791d627dad4910eb665" 297 | ], 298 | "version": "==1.0" 299 | }, 300 | "mistune": { 301 | "hashes": [ 302 | "sha256:b4c512ce2fc99e5a62eb95a4aba4b73e5f90264115c40b70a21e1f7d4e0eac91", 303 | "sha256:bc10c33bfdcaa4e749b779f62f60d6e12f8215c46a292d05e486b869ae306619" 304 | ], 305 | "version": "==0.8.3" 306 | }, 307 | "nbconvert": { 308 | "hashes": [ 309 | "sha256:08d21cf4203fabafd0d09bbd63f06131b411db8ebeede34b0fd4be4548351779", 310 | "sha256:a8a2749f972592aa9250db975304af6b7337f32337e523a2c995cc9e12c07807" 311 | ], 312 | "version": "==5.4.0" 313 | }, 314 | "nbformat": { 315 | "hashes": [ 316 | "sha256:b9a0dbdbd45bb034f4f8893cafd6f652ea08c8c1674ba83f2dc55d3955743b0b", 317 | "sha256:f7494ef0df60766b7cabe0a3651556345a963b74dbc16bc7c18479041170d402" 318 | ], 319 | "version": "==4.4.0" 320 | }, 321 | "notebook": { 322 | "hashes": [ 323 | "sha256:66dd59e76e755584ae9450eb015c39f55d4bb1d8ec68f2c694d2b3cba7bf5c7e", 324 | "sha256:e2c8e931cc19db4f8c63e6a396efbc13a228b2cb5b2919df011b946f28239a08" 325 | ], 326 | "version": "==5.6.0" 327 | }, 328 | "pandocfilters": { 329 | "hashes": [ 330 | "sha256:b3dd70e169bb5449e6bc6ff96aea89c5eea8c5f6ab5e207fc2f521a2cf4a0da9" 331 | ], 332 | "version": "==1.4.2" 333 | }, 334 | "parso": { 335 | "hashes": [ 336 | "sha256:35704a43a3c113cce4de228ddb39aab374b8004f4f2407d070b6a2ca784ce8a2", 337 | "sha256:895c63e93b94ac1e1690f5fdd40b65f07c8171e3e53cbd7793b5b96c0e0a7f24" 338 | ], 339 | "version": "==0.3.1" 340 | }, 341 | "pexpect": { 342 | "hashes": [ 343 | "sha256:2a8e88259839571d1251d278476f3eec5db26deb73a70be5ed5dc5435e418aba", 344 | "sha256:3fbd41d4caf27fa4a377bfd16fef87271099463e6fa73e92a52f92dfee5d425b" 345 | ], 346 | "markers": "sys_platform != 'win32'", 347 | "version": "==4.6.0" 348 | }, 349 | "pickleshare": { 350 | "hashes": [ 351 | "sha256:84a9257227dfdd6fe1b4be1319096c20eb85ff1e82c7932f36efccfe1b09737b", 352 | "sha256:c9a2541f25aeabc070f12f452e1f2a8eae2abd51e1cd19e8430402bdf4c1d8b5" 353 | ], 354 | "version": "==0.7.4" 355 | }, 356 | "pkginfo": { 357 | "hashes": [ 358 | "sha256:5878d542a4b3f237e359926384f1dde4e099c9f5525d236b1840cf704fa8d474", 359 | "sha256:a39076cb3eb34c333a0dd390b568e9e1e881c7bf2cc0aee12120636816f55aee" 360 | ], 361 | "version": "==1.4.2" 362 | }, 363 | "prometheus-client": { 364 | "hashes": [ 365 | "sha256:17bc24c09431644f7c65d7bce9f4237252308070b6395d6d8e87767afe867e24" 366 | ], 367 | "version": "==0.3.1" 368 | }, 369 | "prompt-toolkit": { 370 | "hashes": [ 371 | "sha256:1df952620eccb399c53ebb359cc7d9a8d3a9538cb34c5a1344bdbeb29fbcc381", 372 | "sha256:3f473ae040ddaa52b52f97f6b4a493cfa9f5920c255a12dc56a7d34397a398a4", 373 | "sha256:858588f1983ca497f1cf4ffde01d978a3ea02b01c8a26a8bbc5cd2e66d816917" 374 | ], 375 | "version": "==1.0.15" 376 | }, 377 | "ptyprocess": { 378 | "hashes": [ 379 | "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", 380 | "sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f" 381 | ], 382 | "markers": "os_name != 'nt'", 383 | "version": "==0.6.0" 384 | }, 385 | "pygments": { 386 | "hashes": [ 387 | "sha256:78f3f434bcc5d6ee09020f92ba487f95ba50f1e3ef83ae96b9d5ffa1bab25c5d", 388 | "sha256:dbae1046def0efb574852fab9e90209b23f556367b5a320c0bcb871c77c3e8cc" 389 | ], 390 | "version": "==2.2.0" 391 | }, 392 | "python-dateutil": { 393 | "hashes": [ 394 | "sha256:1adb80e7a782c12e52ef9a8182bebeb73f1d7e24e374397af06fb4956c8dc5c0", 395 | "sha256:e27001de32f627c22380a688bcc43ce83504a7bc5da472209b4c70f02829f0b8" 396 | ], 397 | "version": "==2.7.3" 398 | }, 399 | "pyzmq": { 400 | "hashes": [ 401 | "sha256:25a0715c8f69cf72f67cfe5a68a3f3ed391c67c063d2257bec0fe7fc2c7f08f8", 402 | "sha256:2bab63759632c6b9e0d5bf19cc63c3b01df267d660e0abcf230cf0afaa966349", 403 | "sha256:30ab49d99b24bf0908ebe1cdfa421720bfab6f93174e4883075b7ff38cc555ba", 404 | "sha256:32c7ca9fc547a91e3c26fc6080b6982e46e79819e706eb414dd78f635a65d946", 405 | "sha256:41219ae72b3cc86d97557fe5b1ef5d1adc1057292ec597b50050874a970a39cf", 406 | "sha256:4b8c48a9a13cea8f1f16622f9bd46127108af14cd26150461e3eab71e0de3e46", 407 | "sha256:55724997b4a929c0d01b43c95051318e26ddbae23565018e138ae2dc60187e59", 408 | "sha256:65f0a4afae59d4fc0aad54a917ab599162613a761b760ba167d66cc646ac3786", 409 | "sha256:6f88591a8b246f5c285ee6ce5c1bf4f6bd8464b7f090b1333a446b6240a68d40", 410 | "sha256:75022a4c60dcd8765bb9ca32f6de75a0ec83b0d96e0309dc479f4c7b21f26cb7", 411 | "sha256:76ea493bfab18dcb090d825f3662b5612e2def73dffc196d51a5194b0294a81d", 412 | "sha256:7b60c045b80709e4e3c085bab9b691e71761b44c2b42dbb047b8b498e7bc16b3", 413 | "sha256:8e6af2f736734aef8ed6f278f9f552ec7f37b1a6b98e59b887484a840757f67d", 414 | "sha256:9ac2298e486524331e26390eac14e4627effd3f8e001d4266ed9d8f1d2d31cce", 415 | "sha256:9ba650f493a9bc1f24feca1d90fce0e5dd41088a252ac9840131dfbdbf3815ca", 416 | "sha256:a02a4a385e394e46012dc83d2e8fd6523f039bb52997c1c34a2e0dd49ed839c1", 417 | "sha256:a3ceee84114d9f5711fa0f4db9c652af0e4636c89eabc9b7f03a3882569dd1ed", 418 | "sha256:a72b82ac1910f2cf61a49139f4974f994984475f771b0faa730839607eeedddf", 419 | "sha256:ab136ac51027e7c484c53138a0fab4a8a51e80d05162eb7b1585583bcfdbad27", 420 | "sha256:c095b224300bcac61e6c445e27f9046981b1ac20d891b2f1714da89d34c637c8", 421 | "sha256:c5cc52d16c06dc2521340d69adda78a8e1031705924e103c0eb8fc8af861d810", 422 | "sha256:d612e9833a89e8177f8c1dc68d7b4ff98d3186cd331acd616b01bbdab67d3a7b", 423 | "sha256:e828376a23c66c6fe90dcea24b4b72cd774f555a6ee94081670872918df87a19", 424 | "sha256:e9767c7ab2eb552796440168d5c6e23a99ecaade08dda16266d43ad461730192", 425 | "sha256:ebf8b800d42d217e4710d1582b0c8bff20cdcb4faad7c7213e52644034300924" 426 | ], 427 | "markers": "python_version >= '2.7' and python_version != '3.2*' and python_version != '3.0*' and python_version != '3.1*'", 428 | "version": "==17.1.2" 429 | }, 430 | "qtconsole": { 431 | "hashes": [ 432 | "sha256:298431d376d71a02eb1a04fe6e72dd4beb82b83423d58b17d532e0af838e62fa", 433 | "sha256:7870b19e6a6b0ab3acc09ee65463c0ca7568b3a01a6902d7c4e1ed2c4fc4e176" 434 | ], 435 | "version": "==4.4.1" 436 | }, 437 | "requests": { 438 | "hashes": [ 439 | "sha256:63b52e3c866428a224f97cab011de738c36aec0185aa91cfacd418b5d58911d1", 440 | "sha256:ec22d826a36ed72a7358ff3fe56cbd4ba69dd7a6718ffd450ff0e9df7a47ce6a" 441 | ], 442 | "version": "==2.19.1" 443 | }, 444 | "requests-toolbelt": { 445 | "hashes": [ 446 | "sha256:42c9c170abc2cacb78b8ab23ac957945c7716249206f90874651971a4acff237", 447 | "sha256:f6a531936c6fa4c6cfce1b9c10d5c4f498d16528d2a54a22ca00011205a187b5" 448 | ], 449 | "version": "==0.8.0" 450 | }, 451 | "scikit-learn": { 452 | "hashes": [ 453 | "sha256:0a718b5ffbd5053fb3f9e1a2e20b7c4f256dd8035e246b907d3117d20bac0260", 454 | "sha256:1725540b754a9967778e9385e1ee2c8db50d5ab70ed835c9f5e36002ffabc169", 455 | "sha256:3e3ce307d7c5c5811658ba8686b24b571a8244eaafe707665ad601f400d5ce98", 456 | "sha256:42ad71502237c9fe300ecf157f5a394df717789a2dde541dd7034b539c70bdcc", 457 | "sha256:42cba716db197e0d1670e2fc13c4cc4a86d5c5358120ccfee6ec427b154e74ff", 458 | "sha256:47b4090b7686642e41176becb7c42ef3cc665d7ee0db5e7ea5d307ec9779327e", 459 | "sha256:51d99a08c8bf689cf60c9d8dca6e3d3e5f6d762def85ad735dcea11fb528a89b", 460 | "sha256:5f7577fbb2399a4712e96cf0e786638168940a876c33735a1b5d5a86ba4b1370", 461 | "sha256:66bfc2b6b15db1725d03ea657ec9184ff09dcbf1ecd834ef85f2edc2c9cbba97", 462 | "sha256:69a34d389d9ca4687ad00af4e11d53686771f484c37366f68617ef656bab16ab", 463 | "sha256:75297f3dd6685f01555f1bb75846995d45650af417280b69c81bf11b6987aed5", 464 | "sha256:9ebb38ab1d0ee143982aed561811903ac6c1abb512ae2b9019b3b65bde63ffb9", 465 | "sha256:a402c1484fe65df42d5dbc22a58e0695fe3afe2b0b229aee2a09c6d60ba8e5c2", 466 | "sha256:aad6b9aac1617bd7efa0450643888bbd3410679a94bc8680d9863825686ef369", 467 | "sha256:ad4db28d3dc16c01df75ed6efb72524537de3839a5d179fcf94094359fc72ec5", 468 | "sha256:b276739a5f863ccacb61999a3067d0895ee291c95502929b2ae56ea1f882e888", 469 | "sha256:b3dc88c4d2bcb26ffc5afe16d053ae28317d7d1de083651defcd5453a04f1563", 470 | "sha256:b3e4681253e95da5aa5c231889a32b084fd997962bf8beda6f796bf422f734b2", 471 | "sha256:c3d852d49d6c1710089d4513702099fa6f8e1aebfedf222319d80c47b0a195f8", 472 | "sha256:c6612e7e43988b8b5e1957150449493a55f9c059de641083df7a964f86f2d1e7", 473 | "sha256:c69e5c6051366a6ac9600d730276db939b1a205e42504ec0b8371f154b0058db", 474 | "sha256:ce121baa8e85ec27c3065281657dcd78adaab7dcb046c7fe96ad4e5a9dcb6610", 475 | "sha256:ed2a9a9bea6ec443b7effe5695c9c168b7bf9a67df6d880729760feda871b6a3", 476 | "sha256:efd842d70b87e3ef3429c3149840b9189d4441ca951ab0cec62c94a964e219d9", 477 | "sha256:f1428af5c381f6eef30ffbc7e047b7c713d4efa5d7bf5e57b62b3fc8d387044b", 478 | "sha256:f6c7bf8cd4de1640b760b47f4d28deb26dbbf9acbe0194cdff54a898e190d872", 479 | "sha256:f8329ac2160ad8bbbac6a507374685ceca3f24ca427fa9ee61a501280e1972d9", 480 | "sha256:fefba2a43b92f8393366093b60efbe984a72a2b41cce16b4002005e4104ef938" 481 | ], 482 | "version": "==0.19.2" 483 | }, 484 | "send2trash": { 485 | "hashes": [ 486 | "sha256:60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2", 487 | "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b" 488 | ], 489 | "version": "==1.5.0" 490 | }, 491 | "simplegeneric": { 492 | "hashes": [ 493 | "sha256:dc972e06094b9af5b855b3df4a646395e43d1c9d0d39ed345b7393560d0b9173" 494 | ], 495 | "version": "==0.8.1" 496 | }, 497 | "six": { 498 | "hashes": [ 499 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 500 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 501 | ], 502 | "version": "==1.11.0" 503 | }, 504 | "terminado": { 505 | "hashes": [ 506 | "sha256:55abf9ade563b8f9be1f34e4233c7b7bde726059947a593322e8a553cc4c067a", 507 | "sha256:65011551baff97f5414c67018e908110693143cfbaeb16831b743fe7cad8b927" 508 | ], 509 | "markers": "python_version != '3.3.*' and python_version >= '2.7' and python_version != '3.0.*' and python_version != '3.2.*' and python_version != '3.1.*'", 510 | "version": "==0.8.1" 511 | }, 512 | "testpath": { 513 | "hashes": [ 514 | "sha256:039fa6a6c9fd3488f8336d23aebbfead5fa602c4a47d49d83845f55a595ec1b4", 515 | "sha256:0d5337839c788da5900df70f8e01015aec141aa3fe7936cb0d0a2953f7ac7609" 516 | ], 517 | "version": "==0.3.1" 518 | }, 519 | "tornado": { 520 | "hashes": [ 521 | "sha256:0662d28b1ca9f67108c7e3b77afabfb9c7e87bde174fbda78186ecedc2499a9d", 522 | "sha256:4e5158d97583502a7e2739951553cbd88a72076f152b4b11b64b9a10c4c49409", 523 | "sha256:732e836008c708de2e89a31cb2fa6c0e5a70cb60492bee6f1ea1047500feaf7f", 524 | "sha256:8154ec22c450df4e06b35f131adc4f2f3a12ec85981a203301d310abf580500f", 525 | "sha256:8e9d728c4579682e837c92fdd98036bd5cdefa1da2aaf6acf26947e6dd0c01c5", 526 | "sha256:d4b3e5329f572f055b587efc57d29bd051589fb5a43ec8898c77a47ec2fa2bbb", 527 | "sha256:e5f2585afccbff22390cddac29849df463b252b711aa2ce7c5f3f342a5b3b444" 528 | ], 529 | "markers": "python_version != '3.3.*' and python_version >= '2.7' and python_version != '3.0.*' and python_version != '3.2.*' and python_version != '3.1.*'", 530 | "version": "==5.1.1" 531 | }, 532 | "tqdm": { 533 | "hashes": [ 534 | "sha256:18f1818ce951aeb9ea162ae1098b43f583f7d057b34d706f66939353d1208889", 535 | "sha256:df02c0650160986bac0218bb07952245fc6960d23654648b5d5526ad5a4128c9" 536 | ], 537 | "markers": "python_version != '3.0.*' and python_version != '3.1.*' and python_version >= '2.6'", 538 | "version": "==4.26.0" 539 | }, 540 | "traitlets": { 541 | "hashes": [ 542 | "sha256:9c4bd2d267b7153df9152698efb1050a5d84982d3384a37b2c1f7723ba3e7835", 543 | "sha256:c6cb5e6f57c5a9bdaa40fa71ce7b4af30298fbab9ece9815b5d995ab6217c7d9" 544 | ], 545 | "version": "==4.3.2" 546 | }, 547 | "twine": { 548 | "hashes": [ 549 | "sha256:08eb132bbaec40c6d25b358f546ec1dc96ebd2638a86eea68769d9e67fe2b129", 550 | "sha256:2fd9a4d9ff0bcacf41fdc40c8cb0cfaef1f1859457c9653fd1b92237cc4e9f25" 551 | ], 552 | "index": "pypi", 553 | "version": "==1.11.0" 554 | }, 555 | "urllib3": { 556 | "hashes": [ 557 | "sha256:a68ac5e15e76e7e5dd2b8f94007233e01effe3e50e8daddf69acfd81cb686baf", 558 | "sha256:b5725a0bd4ba422ab0e66e89e030c806576753ea3ee08554382c14e685d117b5" 559 | ], 560 | "markers": "python_version != '3.3.*' and python_version < '4' and python_version != '3.1.*' and python_version != '3.0.*' and python_version != '3.2.*' and python_version >= '2.6'", 561 | "version": "==1.23" 562 | }, 563 | "wcwidth": { 564 | "hashes": [ 565 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 566 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 567 | ], 568 | "version": "==0.1.7" 569 | }, 570 | "webencodings": { 571 | "hashes": [ 572 | "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", 573 | "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923" 574 | ], 575 | "version": "==0.5.1" 576 | }, 577 | "wheel": { 578 | "hashes": [ 579 | "sha256:0a2e54558a0628f2145d2fc822137e322412115173e8a2ddbe1c9024338ae83c", 580 | "sha256:80044e51ec5bbf6c894ba0bc48d26a8c20a9ba629f4ca19ea26ecfcf87685f5f" 581 | ], 582 | "index": "pypi", 583 | "version": "==0.31.1" 584 | }, 585 | "widgetsnbextension": { 586 | "hashes": [ 587 | "sha256:14b2c65f9940c9a7d3b70adbe713dbd38b5ec69724eebaba034d1036cf3d4740", 588 | "sha256:fa618be8435447a017fd1bf2c7ae922d0428056cfc7449f7a8641edf76b48265" 589 | ], 590 | "version": "==3.4.2" 591 | } 592 | } 593 | } 594 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # active-semi-supervised-clustering 2 | 3 | Active semi-supervised clustering algorithms for scikit-learn. 4 | 5 | ## Algorithms 6 | 7 | ### Semi-supervised clustering 8 | 9 | * Seeded-KMeans 10 | * Constrainted-KMeans 11 | * COP-KMeans 12 | * Pairwise constrained K-Means (PCK-Means) 13 | * Metric K-Means (MK-Means) 14 | * Metric pairwise constrained K-Means (MPCK-Means) 15 | 16 | ### Active learning of pairwise clustering 17 | 18 | * Explore & Consolidate 19 | * Min-max 20 | * Normalized point-based uncertainty (NPU) method 21 | 22 | ## Installation 23 | 24 | ``` 25 | pip install active-semi-supervised-clustering 26 | ``` 27 | 28 | ## Usage 29 | 30 | ```python 31 | from sklearn import datasets, metrics 32 | from active_semi_clustering.semi_supervised.pairwise_constraints import PCKMeans 33 | from active_semi_clustering.active.pairwise_constraints import ExampleOracle, ExploreConsolidate, MinMax 34 | ``` 35 | 36 | ```python 37 | X, y = datasets.load_iris(return_X_y=True) 38 | ``` 39 | 40 | First, obtain some pairwise constraints from an oracle. 41 | 42 | ```python 43 | # TODO implement your own oracle that will, for example, query a domain expert via GUI or CLI 44 | oracle = ExampleOracle(y, max_queries_cnt=10) 45 | 46 | active_learner = MinMax(n_clusters=3) 47 | active_learner.fit(X, oracle=oracle) 48 | pairwise_constraints = active_learner.pairwise_constraints_ 49 | ``` 50 | 51 | Then, use the constraints to do the clustering. 52 | 53 | ```python 54 | clusterer = PCKMeans(n_clusters=3) 55 | clusterer.fit(X, ml=pairwise_constraints[0], cl=pairwise_constraints[1]) 56 | ``` 57 | 58 | Evaluate the clustering using Adjusted Rand Score. 59 | 60 | ```python 61 | metrics.adjusted_rand_score(y, clusterer.labels_) 62 | ``` 63 | -------------------------------------------------------------------------------- /active_semi_clustering/__init__.py: -------------------------------------------------------------------------------- 1 | from .semi_supervised.labeled_data import KMeans, SeededKMeans, ConstrainedKMeans 2 | from .semi_supervised.pairwise_constraints import COPKMeans, PCKMeans, MPCKMeans, MPCKMeansMF, MKMeans, RCAKMeans -------------------------------------------------------------------------------- /active_semi_clustering/active/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamole-ai/active-semi-supervised-clustering/0dcab86ea22cd66ed7ea64234efdff1c6aca7981/active_semi_clustering/active/__init__.py -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/__init__.py: -------------------------------------------------------------------------------- 1 | from .explore_consolidate import ExploreConsolidate 2 | from .min_max import MinMax 3 | from .npu import NPU 4 | from .example_oracle import ExampleOracle -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/example_oracle.py: -------------------------------------------------------------------------------- 1 | class MaximumQueriesExceeded(Exception): 2 | pass 3 | 4 | 5 | class ExampleOracle: 6 | def __init__(self, labels, max_queries_cnt=20): 7 | self.labels = labels 8 | self.queries_cnt = 0 9 | self.max_queries_cnt = max_queries_cnt 10 | 11 | def query(self, i, j): 12 | "Query the oracle to find out whether i and j should be must-linked" 13 | if self.queries_cnt < self.max_queries_cnt: 14 | self.queries_cnt += 1 15 | return self.labels[i] == self.labels[j] 16 | else: 17 | raise MaximumQueriesExceeded 18 | -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/explore_consolidate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .helpers import get_constraints_from_neighborhoods 4 | from .example_oracle import MaximumQueriesExceeded 5 | 6 | 7 | class ExploreConsolidate: 8 | def __init__(self, n_clusters=3, **kwargs): 9 | self.n_clusters = n_clusters 10 | 11 | def fit(self, X, oracle=None): 12 | if oracle.max_queries_cnt <= 0: 13 | return [], [] 14 | 15 | neighborhoods = self._explore(X, self.n_clusters, oracle) 16 | neighborhoods = self._consolidate(neighborhoods, X, oracle) 17 | 18 | self.pairwise_constraints_ = get_constraints_from_neighborhoods(neighborhoods) 19 | 20 | return self 21 | 22 | def _explore(self, X, k, oracle): 23 | neighborhoods = [] 24 | traversed = [] 25 | n = X.shape[0] 26 | 27 | x = np.random.choice(n) 28 | neighborhoods.append([x]) 29 | traversed.append(x) 30 | 31 | try: 32 | while len(neighborhoods) < k: 33 | 34 | max_distance = 0 35 | farthest = None 36 | 37 | for i in range(n): 38 | if i not in traversed: 39 | distance = dist(i, traversed, X) 40 | if distance > max_distance: 41 | max_distance = distance 42 | farthest = i 43 | 44 | new_neighborhood = True 45 | for neighborhood in neighborhoods: 46 | if oracle.query(farthest, neighborhood[0]): 47 | neighborhood.append(farthest) 48 | new_neighborhood = False 49 | break 50 | 51 | if new_neighborhood: 52 | neighborhoods.append([farthest]) 53 | 54 | traversed.append(farthest) 55 | 56 | except MaximumQueriesExceeded: 57 | pass 58 | 59 | return neighborhoods 60 | 61 | def _consolidate(self, neighborhoods, X, oracle): 62 | n = X.shape[0] 63 | 64 | neighborhoods_union = set() 65 | for neighborhood in neighborhoods: 66 | for i in neighborhood: 67 | neighborhoods_union.add(i) 68 | 69 | remaining = set() 70 | for i in range(n): 71 | if i not in neighborhoods_union: 72 | remaining.add(i) 73 | 74 | while True: 75 | 76 | try: 77 | i = np.random.choice(list(remaining)) 78 | 79 | sorted_neighborhoods = sorted(neighborhoods, key=lambda neighborhood: dist(i, neighborhood, X)) 80 | 81 | for neighborhood in sorted_neighborhoods: 82 | if oracle.query(i, neighborhood[0]): 83 | neighborhood.append(i) 84 | break 85 | 86 | neighborhoods_union.add(i) 87 | remaining.remove(i) 88 | 89 | except MaximumQueriesExceeded: 90 | break 91 | 92 | return neighborhoods 93 | 94 | 95 | def dist(i, S, points): 96 | distances = np.array([np.sqrt(((points[i] - points[j]) ** 2).sum()) for j in S]) 97 | return distances.min() 98 | -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/helpers.py: -------------------------------------------------------------------------------- 1 | def get_constraints_from_neighborhoods(neighborhoods): 2 | ml = [] 3 | 4 | for neighborhood in neighborhoods: 5 | for i in neighborhood: 6 | for j in neighborhood: 7 | if i != j: 8 | ml.append((i, j)) 9 | 10 | cl = [] 11 | for neighborhood in neighborhoods: 12 | for other_neighborhood in neighborhoods: 13 | if neighborhood != other_neighborhood: 14 | for i in neighborhood: 15 | for j in other_neighborhood: 16 | cl.append((i, j)) 17 | 18 | return ml, cl 19 | -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/min_max.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .example_oracle import MaximumQueriesExceeded 4 | from .explore_consolidate import ExploreConsolidate 5 | 6 | 7 | class MinMax(ExploreConsolidate): 8 | def _consolidate(self, neighborhoods, X, oracle): 9 | n = X.shape[0] 10 | 11 | skeleton = set() 12 | for neighborhood in neighborhoods: 13 | for i in neighborhood: 14 | skeleton.add(i) 15 | 16 | remaining = set() 17 | for i in range(n): 18 | if i not in skeleton: 19 | remaining.add(i) 20 | 21 | distances = np.zeros((n, n)) 22 | for i in range(n): 23 | for j in range(n): 24 | distances[i, j] = np.sqrt(((X[i] - X[j]) ** 2).sum()) 25 | 26 | kernel_width = np.percentile(distances, 20) 27 | 28 | while True: 29 | try: 30 | max_similarities = np.full(n, fill_value=float('+inf')) 31 | for x_i in remaining: 32 | max_similarities[x_i] = np.max([similarity(X[x_i], X[x_j], kernel_width) for x_j in skeleton]) 33 | 34 | q_i = max_similarities.argmin() 35 | 36 | sorted_neighborhoods = reversed(sorted(neighborhoods, key=lambda neighborhood: np.max([similarity(X[q_i], X[n_i], kernel_width) for n_i in neighborhood]))) 37 | 38 | for neighborhood in sorted_neighborhoods: 39 | if oracle.query(q_i, neighborhood[0]): 40 | neighborhood.append(q_i) 41 | break 42 | 43 | skeleton.add(q_i) 44 | remaining.remove(q_i) 45 | 46 | except MaximumQueriesExceeded: 47 | break 48 | 49 | return neighborhoods 50 | 51 | 52 | def similarity(x, y, kernel_width): 53 | return np.exp(-((x - y) ** 2).sum() / (2 * (kernel_width ** 2))) 54 | -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/npu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.ensemble import RandomForestClassifier 3 | 4 | from .example_oracle import MaximumQueriesExceeded 5 | from active_semi_clustering.exceptions import EmptyClustersException 6 | 7 | 8 | class NPU: 9 | def __init__(self, clusterer=None, **kwargs): 10 | self.clusterer = clusterer 11 | 12 | def fit(self, X, oracle=None): 13 | n = X.shape[0] 14 | ml, cl = [], [] 15 | neighborhoods = [] 16 | 17 | x_i = np.random.choice(list(range(n))) 18 | neighborhoods.append([x_i]) 19 | 20 | while True: 21 | try: 22 | while True: 23 | try: 24 | self.clusterer.fit(X, ml=ml, cl=cl) 25 | except EmptyClustersException: 26 | continue 27 | break 28 | 29 | x_i, p_i = self._most_informative(X, self.clusterer, neighborhoods) 30 | 31 | sorted_neighborhoods = list(zip(*reversed(sorted(zip(p_i, neighborhoods)))))[1] 32 | # print(x_i, neighborhoods, p_i, sorted_neighborhoods) 33 | 34 | must_link_found = False 35 | 36 | for neighborhood in sorted_neighborhoods: 37 | 38 | must_linked = oracle.query(x_i, neighborhood[0]) 39 | if must_linked: 40 | # TODO is it necessary? this preprocessing is part of the clustering algorithms 41 | for x_j in neighborhood: 42 | ml.append([x_i, x_j]) 43 | 44 | for other_neighborhood in neighborhoods: 45 | if neighborhood != other_neighborhood: 46 | for x_j in other_neighborhood: 47 | cl.append([x_i, x_j]) 48 | 49 | neighborhood.append(x_i) 50 | must_link_found = True 51 | break 52 | 53 | # TODO should we add the cannot-link in case the algorithm stops before it queries all neighborhoods? 54 | 55 | if not must_link_found: 56 | for neighborhood in neighborhoods: 57 | for x_j in neighborhood: 58 | cl.append([x_i, x_j]) 59 | 60 | neighborhoods.append([x_i]) 61 | 62 | except MaximumQueriesExceeded: 63 | break 64 | 65 | self.pairwise_constraints_ = ml, cl 66 | 67 | return self 68 | 69 | def _most_informative(self, X, clusterer, neighborhoods): 70 | n = X.shape[0] 71 | l = len(neighborhoods) 72 | 73 | neighborhoods_union = set() 74 | for neighborhood in neighborhoods: 75 | for i in neighborhood: 76 | neighborhoods_union.add(i) 77 | 78 | unqueried_indices = set(range(n)) - neighborhoods_union 79 | 80 | # TODO if there is only one neighborhood then choose the point randomly? 81 | if l <= 1: 82 | return np.random.choice(list(unqueried_indices)), [1] 83 | 84 | # Learn a random forest classifier 85 | n_estimators = 50 86 | rf = RandomForestClassifier(n_estimators=n_estimators) 87 | rf.fit(X, clusterer.labels_) 88 | 89 | # Compute the similarity matrix 90 | leaf_indices = rf.apply(X) 91 | S = np.zeros((n, n)) 92 | for i in range(n): 93 | for j in range(n): 94 | S[i, j] = (leaf_indices[i,] == leaf_indices[j,]).sum() 95 | S = S / n_estimators 96 | 97 | p = np.empty((n, l)) 98 | uncertainties = np.zeros(n) 99 | expected_costs = np.ones(n) 100 | 101 | # For each point that is not in any neighborhood... 102 | # TODO iterate only unqueried indices 103 | for x_i in range(n): 104 | if not x_i in neighborhoods_union: 105 | for n_i in range(l): 106 | p[x_i, n_i] = (S[x_i, neighborhoods[n_i]].sum() / len(neighborhoods[n_i])) 107 | 108 | # If the point is not similar to any neighborhood set equal probabilities of belonging to each neighborhood 109 | if np.all(p[x_i,] == 0): 110 | p[x_i,] = np.ones(l) 111 | 112 | p[x_i,] = p[x_i,] / p[x_i,].sum() 113 | 114 | if not np.any(p[x_i,] == 1): 115 | positive_p_i = p[x_i, p[x_i,] > 0] 116 | uncertainties[x_i] = -(positive_p_i * np.log2(positive_p_i)).sum() 117 | expected_costs[x_i] = (positive_p_i * range(1, len(positive_p_i) + 1)).sum() 118 | else: 119 | uncertainties[x_i] = 0 120 | expected_costs[x_i] = 1 # ? 121 | 122 | normalized_uncertainties = uncertainties / expected_costs 123 | 124 | most_informative_i = np.argmax(normalized_uncertainties) 125 | return most_informative_i, p[most_informative_i] 126 | -------------------------------------------------------------------------------- /active_semi_clustering/active/pairwise_constraints/random.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Random: 5 | def __init__(self, n_clusters=3, **kwargs): 6 | self.n_clusters = n_clusters 7 | 8 | def fit(self, X, oracle=None): 9 | constraints = [np.random.choice(range(X.shape[0]), size=2, replace=False).tolist() for _ in range(oracle.max_queries_cnt)] 10 | 11 | ml, cl = [], [] 12 | 13 | for i, j in constraints: 14 | must_linked = oracle.query(i, j) 15 | if must_linked: 16 | ml.append((i, j)) 17 | else: 18 | cl.append((i, j)) 19 | 20 | self.pairwise_constraints_ = ml, cl 21 | 22 | return self 23 | -------------------------------------------------------------------------------- /active_semi_clustering/exceptions.py: -------------------------------------------------------------------------------- 1 | class ClusteringNotFoundException(Exception): 2 | pass 3 | 4 | 5 | class EmptyClustersException(Exception): 6 | pass 7 | 8 | 9 | class InconsistentConstraintsException(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /active_semi_clustering/farthest_first_traversal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def dist(i, S, points): 5 | distances = np.array([np.sqrt(((points[i] - points[j]) ** 2).sum()) for j in S]) 6 | return distances.min() 7 | 8 | 9 | def farthest_first_traversal(points, k): 10 | traversed = [] 11 | 12 | # Choose the first point randomly 13 | i = np.random.choice(len(points)) 14 | traversed.append(i) 15 | 16 | # Find remaining n - 1 maximally separated points 17 | for _ in range(k - 1): 18 | max_dst, max_dst_index = 0, None 19 | 20 | for i in range(len(points)): 21 | if i not in traversed: 22 | dst = dist(i, traversed, points) 23 | 24 | if dst > max_dst: 25 | max_dst = dst 26 | max_dst_index = i 27 | 28 | traversed.append(max_dst_index) 29 | 30 | return traversed 31 | 32 | 33 | def weighted_farthest_first_traversal(points, weights, k): 34 | traversed = [] 35 | 36 | # Choose the first point randomly (weighted) 37 | i = np.random.choice(len(points), size=1, p=weights)[0] 38 | traversed.append(i) 39 | 40 | # Find remaining n - 1 maximally separated points 41 | for _ in range(k - 1): 42 | max_dst, max_dst_index = 0, None 43 | 44 | for i in range(len(points)): 45 | if i not in traversed: 46 | dst = dist(i, traversed, points) 47 | weighted_dst = weights[i] * dst 48 | 49 | if weighted_dst > max_dst: 50 | max_dst = weighted_dst 51 | max_dst_index = i 52 | 53 | traversed.append(max_dst_index) 54 | 55 | return traversed 56 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamole-ai/active-semi-supervised-clustering/0dcab86ea22cd66ed7ea64234efdff1c6aca7981/active_semi_clustering/semi_supervised/__init__.py -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/labeled_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .kmeans import KMeans 2 | from .seededkmeans import SeededKMeans 3 | from .constrainedkmeans import ConstrainedKMeans -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/labeled_data/constrainedkmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .kmeans import EmptyClustersException 4 | from .seededkmeans import SeededKMeans 5 | 6 | 7 | class ConstrainedKMeans(SeededKMeans): 8 | def _assign_clusters(self, X, y, cluster_centers, dist): 9 | labels = np.full(X.shape[0], fill_value=-1) 10 | 11 | for i, x in enumerate(X): 12 | if y[i] != -1: 13 | labels[i] = y[i] 14 | else: 15 | labels[i] = np.argmin([dist(x, c) for c in cluster_centers]) 16 | 17 | # Handle empty clusters 18 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 19 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 20 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 21 | 22 | if len(empty_clusters) > 0: 23 | raise EmptyClustersException 24 | 25 | return labels 26 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/labeled_data/kmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from active_semi_clustering.exceptions import EmptyClustersException 4 | 5 | 6 | class KMeans: 7 | def __init__(self, n_clusters=3, max_iter=100): 8 | self.n_clusters = n_clusters 9 | self.max_iter = max_iter 10 | 11 | def fit(self, X, y=None, **kwargs): 12 | # Initialize cluster centers 13 | cluster_centers = self._init_cluster_centers(X, y) 14 | 15 | # Repeat until convergence 16 | for iteration in range(self.max_iter): 17 | prev_cluster_centers = cluster_centers.copy() 18 | 19 | # Assign clusters 20 | labels = self._assign_clusters(X, y, cluster_centers, self._dist) 21 | 22 | # Estimate means 23 | cluster_centers = self._get_cluster_centers(X, labels) 24 | 25 | # Check for convergence 26 | cluster_centers_shift = (prev_cluster_centers - cluster_centers) 27 | converged = np.allclose(cluster_centers_shift, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0) 28 | 29 | if converged: break 30 | 31 | self.cluster_centers_, self.labels_ = cluster_centers, labels 32 | 33 | return self 34 | 35 | def _init_cluster_centers(self, X, y=None): 36 | return X[np.random.choice(X.shape[0], self.n_clusters, replace=False), :] 37 | 38 | def _dist(self, x, y): 39 | return np.sqrt(np.sum((x - y) ** 2)) 40 | 41 | def _assign_clusters(self, X, y, cluster_centers, dist): 42 | labels = np.full(X.shape[0], fill_value=-1) 43 | 44 | for i, x in enumerate(X): 45 | labels[i] = np.argmin([dist(x, c) for c in cluster_centers]) 46 | 47 | # Handle empty clusters 48 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 49 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 50 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 51 | 52 | if len(empty_clusters) > 0: 53 | raise EmptyClustersException 54 | 55 | return labels 56 | 57 | def _get_cluster_centers(self, X, labels): 58 | return np.array([X[labels == i].mean(axis=0) for i in range(self.n_clusters)]) 59 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/labeled_data/seededkmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .kmeans import KMeans 4 | 5 | 6 | class SeededKMeans(KMeans): 7 | def _init_cluster_centers(self, X, y=None): 8 | if np.all(y == -1): 9 | return X[np.random.choice(X.shape[0], self.n_clusters, replace=False), :] 10 | else: 11 | return self._get_cluster_centers(X, y) 12 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/__init__.py: -------------------------------------------------------------------------------- 1 | from .copkmeans import COPKMeans 2 | from .pckmeans import PCKMeans 3 | from .mpckmeans import MPCKMeans 4 | from .mpckmeansmf import MPCKMeansMF 5 | from .mkmeans import MKMeans 6 | from .rcakmeans import RCAKMeans -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/constraints.py: -------------------------------------------------------------------------------- 1 | from active_semi_clustering.exceptions import InconsistentConstraintsException 2 | 3 | 4 | # Taken from https://github.com/Behrouz-Babaki/COP-Kmeans/blob/master/copkmeans/cop_kmeans.py 5 | def preprocess_constraints(ml, cl, n): 6 | "Create a graph of constraints for both must- and cannot-links" 7 | 8 | # Represent the graphs using adjacency-lists 9 | ml_graph, cl_graph = {}, {} 10 | for i in range(n): 11 | ml_graph[i] = set() 12 | cl_graph[i] = set() 13 | 14 | def add_both(d, i, j): 15 | d[i].add(j) 16 | d[j].add(i) 17 | 18 | for (i, j) in ml: 19 | ml_graph[i].add(j) 20 | ml_graph[j].add(i) 21 | 22 | for (i, j) in cl: 23 | cl_graph[i].add(j) 24 | cl_graph[j].add(i) 25 | 26 | def dfs(i, graph, visited, component): 27 | visited[i] = True 28 | for j in graph[i]: 29 | if not visited[j]: 30 | dfs(j, graph, visited, component) 31 | component.append(i) 32 | 33 | # Run DFS from each node to get all the graph's components 34 | # and add an edge for each pair of nodes in the component (create a complete graph) 35 | # See http://www.techiedelight.com/transitive-closure-graph/ for more details 36 | visited = [False] * n 37 | neighborhoods = [] 38 | for i in range(n): 39 | if not visited[i] and ml_graph[i]: 40 | component = [] 41 | dfs(i, ml_graph, visited, component) 42 | for x1 in component: 43 | for x2 in component: 44 | if x1 != x2: 45 | ml_graph[x1].add(x2) 46 | neighborhoods.append(component) 47 | 48 | for (i, j) in cl: 49 | for x in ml_graph[i]: 50 | add_both(cl_graph, x, j) 51 | 52 | for y in ml_graph[j]: 53 | add_both(cl_graph, i, y) 54 | 55 | for x in ml_graph[i]: 56 | for y in ml_graph[j]: 57 | add_both(cl_graph, x, y) 58 | 59 | for i in ml_graph: 60 | for j in ml_graph[i]: 61 | if j != i and j in cl_graph[i]: 62 | raise InconsistentConstraintsException('Inconsistent constraints between {} and {}'.format(i, j)) 63 | 64 | return ml_graph, cl_graph, neighborhoods 65 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/copkmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from active_semi_clustering.exceptions import EmptyClustersException, ClusteringNotFoundException 4 | from .constraints import preprocess_constraints 5 | 6 | 7 | class COPKMeans: 8 | def __init__(self, n_clusters=3, max_iter=100): 9 | self.n_clusters = n_clusters 10 | self.max_iter = max_iter 11 | 12 | def fit(self, X, y=None, ml=[], cl=[]): 13 | ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0]) 14 | 15 | # Initialize cluster centers 16 | cluster_centers = self._init_cluster_centers(X) 17 | 18 | # Repeat until convergence 19 | for iteration in range(self.max_iter): 20 | prev_cluster_centers = cluster_centers.copy() 21 | 22 | # Assign clusters 23 | labels = self._assign_clusters(X, cluster_centers, self._dist, ml_graph, cl_graph) 24 | 25 | # Estimate means 26 | cluster_centers = self._get_cluster_centers(X, labels) 27 | 28 | # Check for convergence 29 | cluster_centers_shift = (prev_cluster_centers - cluster_centers) 30 | converged = np.allclose(cluster_centers_shift, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0) 31 | 32 | if converged: break 33 | 34 | self.cluster_centers_, self.labels_ = cluster_centers, labels 35 | 36 | return self 37 | 38 | def _init_cluster_centers(self, X): 39 | return X[np.random.choice(X.shape[0], self.n_clusters, replace=False), :] 40 | 41 | def _dist(self, x, y): 42 | return np.sqrt(np.sum((x - y) ** 2)) 43 | 44 | def _assign_clusters(self, *args): 45 | max_retries_cnt = 1000 46 | 47 | for retries_cnt in range(max_retries_cnt): 48 | try: 49 | return self._try_assign_clusters(*args) 50 | 51 | except ClusteringNotFoundException: 52 | continue 53 | 54 | raise ClusteringNotFoundException 55 | 56 | def _try_assign_clusters(self, X, cluster_centers, dist, ml_graph, cl_graph): 57 | labels = np.full(X.shape[0], fill_value=-1) 58 | 59 | data_indices = list(range(X.shape[0])) 60 | np.random.shuffle(data_indices) 61 | 62 | for i in data_indices: 63 | distances = np.array([dist(X[i], c) for c in cluster_centers]) 64 | # sorted_cluster_indices = np.argsort([dist(x, c) for c in cluster_centers]) 65 | 66 | for cluster_index in distances.argsort(): 67 | if not self._violates_constraints(i, cluster_index, labels, ml_graph, cl_graph): 68 | labels[i] = cluster_index 69 | break 70 | 71 | if labels[i] < 0: 72 | raise ClusteringNotFoundException 73 | 74 | # Handle empty clusters 75 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 76 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 77 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 78 | 79 | if len(empty_clusters) > 0: 80 | raise EmptyClustersException 81 | 82 | return labels 83 | 84 | def _violates_constraints(self, i, cluster_index, labels, ml_graph, cl_graph): 85 | for j in ml_graph[i]: 86 | if labels[j] > 0 and cluster_index != labels[j]: 87 | return True 88 | 89 | for j in cl_graph[i]: 90 | if cluster_index == labels[j]: 91 | return True 92 | 93 | return False 94 | 95 | def _get_cluster_centers(self, X, labels): 96 | return np.array([X[labels == i].mean(axis=0) for i in range(self.n_clusters)]) 97 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/mkmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.cluster import KMeans 4 | from metric_learn import MMC 5 | 6 | 7 | class MKMeans: 8 | def __init__(self, n_clusters=3, max_iter=1000, diagonal=True): 9 | self.n_clusters = n_clusters 10 | self.max_iter = max_iter 11 | self.diagonal = diagonal 12 | 13 | def fit(self, X, y=None, ml=[], cl=[]): 14 | X_transformed = X 15 | 16 | if ml and cl: 17 | # ml_graph, cl_graph, _ = preprocess_constraints(ml, cl, X.shape[0]) 18 | # 19 | # ml, cl = [], [] 20 | # for i, constraints in ml_graph.items(): 21 | # for j in constraints: 22 | # ml.append((i, j)) 23 | # 24 | # for i, constraints in cl_graph.items(): 25 | # for j in constraints: 26 | # cl.append((i, j)) 27 | 28 | constraints = [np.array(lst) for lst in [*zip(*ml), *zip(*cl)]] 29 | mmc = MMC(diagonal=self.diagonal) 30 | mmc.fit(X, constraints=constraints) 31 | X_transformed = mmc.transform(X) 32 | 33 | kmeans = KMeans(n_clusters=self.n_clusters, init='random', max_iter=self.max_iter) 34 | kmeans.fit(X_transformed) 35 | 36 | self.labels_ = kmeans.labels_ 37 | 38 | return self 39 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/mpckmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | 4 | from active_semi_clustering.exceptions import EmptyClustersException 5 | from active_semi_clustering.farthest_first_traversal import weighted_farthest_first_traversal 6 | from .constraints import preprocess_constraints 7 | 8 | np.seterr('raise') 9 | 10 | 11 | class MPCKMeans: 12 | "MPCK-Means-S-D that learns only a single (S) diagonal (D) matrix" 13 | 14 | def __init__(self, n_clusters=3, max_iter=10, w=1): 15 | self.n_clusters = n_clusters 16 | self.max_iter = max_iter 17 | self.w = w 18 | 19 | def fit(self, X, y=None, ml=[], cl=[]): 20 | # Preprocess constraints 21 | ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0]) 22 | 23 | # Initialize cluster centers 24 | cluster_centers = self._initialize_cluster_centers(X, neighborhoods) 25 | 26 | # Initialize metrics 27 | A = np.identity(X.shape[1]) 28 | 29 | # Repeat until convergence 30 | for iteration in range(self.max_iter): 31 | prev_cluster_centers = cluster_centers.copy() 32 | 33 | # Find farthest pair of points according to each metric 34 | farthest = self._find_farthest_pairs_of_points(X, A) 35 | 36 | # Assign clusters 37 | labels = self._assign_clusters(X, y, cluster_centers, A, farthest, ml_graph, cl_graph, self.w) 38 | 39 | # Estimate means 40 | cluster_centers = self._get_cluster_centers(X, labels) 41 | 42 | # Update metrics 43 | A = self._update_metrics(X, labels, cluster_centers, farthest, ml_graph, cl_graph, self.w) 44 | 45 | # Check for convergence 46 | cluster_centers_shift = (prev_cluster_centers - cluster_centers) 47 | converged = np.allclose(cluster_centers_shift, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0) 48 | 49 | if converged: 50 | break 51 | 52 | # print('\t', iteration, converged) 53 | 54 | self.cluster_centers_, self.labels_ = cluster_centers, labels 55 | 56 | return self 57 | 58 | def _find_farthest_pairs_of_points(self, X, A): 59 | farthest = None 60 | n = X.shape[0] 61 | max_distance = 0 62 | 63 | for i in range(n): 64 | for j in range(n): 65 | if j < i: 66 | distance = self._dist(X[i], X[j], A) 67 | if distance > max_distance: 68 | max_distance = distance 69 | farthest = (i, j, distance) 70 | 71 | assert farthest is not None 72 | 73 | return farthest 74 | 75 | def _initialize_cluster_centers(self, X, neighborhoods): 76 | neighborhood_centers = np.array([X[neighborhood].mean(axis=0) for neighborhood in neighborhoods]) 77 | neighborhood_sizes = np.array([len(neighborhood) for neighborhood in neighborhoods]) 78 | neighborhood_weights = neighborhood_sizes / neighborhood_sizes.sum() 79 | 80 | # print('\t', len(neighborhoods), neighborhood_sizes) 81 | 82 | if len(neighborhoods) > self.n_clusters: 83 | cluster_centers = neighborhood_centers[weighted_farthest_first_traversal(neighborhood_centers, neighborhood_weights, self.n_clusters)] 84 | else: 85 | if len(neighborhoods) > 0: 86 | cluster_centers = neighborhood_centers 87 | else: 88 | cluster_centers = np.empty((0, X.shape[1])) 89 | 90 | if len(neighborhoods) < self.n_clusters: 91 | remaining_cluster_centers = X[np.random.choice(X.shape[0], self.n_clusters - len(neighborhoods), replace=False), :] 92 | cluster_centers = np.concatenate([cluster_centers, remaining_cluster_centers]) 93 | 94 | return cluster_centers 95 | 96 | def _dist(self, x, y, A): 97 | "(x - y)^T A (x - y)" 98 | return scipy.spatial.distance.mahalanobis(x, y, A) ** 2 99 | 100 | def _objective_fn(self, X, i, labels, cluster_centers, cluster_id, A, farthest, ml_graph, cl_graph, w): 101 | term_d = self._dist(X[i], cluster_centers[cluster_id], A) - np.log(np.linalg.det(A)) / np.log(2) # FIXME is it okay that it might be negative? 102 | 103 | def f_m(i, j, A): 104 | return self._dist(X[i], X[j], A) 105 | 106 | def f_c(i, j, A, farthest): 107 | return farthest[2] - self._dist(X[i], X[j], A) 108 | 109 | term_m = 0 110 | for j in ml_graph[i]: 111 | if labels[j] >= 0 and labels[j] != cluster_id: 112 | term_m += 2 * w * f_m(i, j, A) 113 | 114 | term_c = 0 115 | for j in cl_graph[i]: 116 | if labels[j] == cluster_id: 117 | # assert f_c(i, j, A, farthest) >= 0 118 | term_c += 2 * w * f_c(i, j, A, farthest) 119 | 120 | return term_d + term_m + term_c 121 | 122 | def _assign_clusters(self, X, y, cluster_centers, A, farthest, ml_graph, cl_graph, w): 123 | labels = np.full(X.shape[0], fill_value=-1) 124 | 125 | index = list(range(X.shape[0])) 126 | np.random.shuffle(index) 127 | for i in index: 128 | labels[i] = np.argmin([self._objective_fn(X, i, labels, cluster_centers, cluster_id, A, farthest, ml_graph, cl_graph, w) for cluster_id, cluster_center in enumerate(cluster_centers)]) 129 | 130 | # Handle empty clusters 131 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 132 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 133 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 134 | 135 | if len(empty_clusters) > 0: 136 | # print("Empty clusters") 137 | raise EmptyClustersException 138 | 139 | return labels 140 | 141 | def _update_metrics(self, X, labels, cluster_centers, farthest, ml_graph, cl_graph, w): 142 | N, D = X.shape 143 | A = np.zeros((D, D)) 144 | 145 | for d in range(D): 146 | term_x = np.sum([(x[d] - cluster_centers[labels[i], d]) ** 2 for i, x in enumerate(X)]) 147 | 148 | term_m = 0 149 | for i in range(N): 150 | for j in ml_graph[i]: 151 | if labels[i] != labels[j]: 152 | term_m += 1 / 2 * w * (X[i, d] - X[j, d]) ** 2 153 | 154 | term_c = 0 155 | for i in range(N): 156 | for j in cl_graph[i]: 157 | if labels[i] == labels[j]: 158 | tmp = ((X[farthest[0], d] - X[farthest[1], d]) ** 2 - (X[i, d] - X[j, d]) ** 2) 159 | term_c += w * max(tmp, 0) 160 | 161 | # print('term_x', term_x, 'term_m', term_m, 'term_c', term_c) 162 | 163 | A[d, d] = N * 1 / max(term_x + term_m + term_c, 1e-9) 164 | 165 | return A 166 | 167 | def _get_cluster_centers(self, X, labels): 168 | return np.array([X[labels == i].mean(axis=0) for i in range(self.n_clusters)]) 169 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/mpckmeansmf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | 4 | from active_semi_clustering.exceptions import EmptyClustersException 5 | from active_semi_clustering.farthest_first_traversal import weighted_farthest_first_traversal 6 | from .constraints import preprocess_constraints 7 | 8 | 9 | # np.seterr('raise') 10 | 11 | class MPCKMeansMF: 12 | """ 13 | MPCK-Means that learns multiple (M) full (F) matrices 14 | """ 15 | 16 | def __init__(self, n_clusters=3, max_iter=100, w=1): 17 | self.n_clusters = n_clusters 18 | self.max_iter = max_iter 19 | self.w = w 20 | 21 | def fit(self, X, y=None, ml=[], cl=[]): 22 | # Preprocess constraints 23 | ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0]) 24 | 25 | # Initialize cluster centers 26 | cluster_centers = self._initialize_cluster_centers(X, neighborhoods) 27 | 28 | # Initialize metrics 29 | As = [np.identity(X.shape[1]) for i in range(self.n_clusters)] 30 | 31 | # Repeat until convergence 32 | for iteration in range(self.max_iter): 33 | prev_cluster_centers = cluster_centers.copy() 34 | 35 | # Find farthest pair of points according to each metric 36 | farthest = self._find_farthest_pairs_of_points(X, As) 37 | 38 | # Assign clusters 39 | labels = self._assign_clusters(X, y, cluster_centers, As, farthest, ml_graph, cl_graph, self.w) 40 | 41 | # Estimate means 42 | cluster_centers = self._get_cluster_centers(X, labels) 43 | 44 | # Update metrics 45 | As = self._update_metrics(X, labels, cluster_centers, farthest, ml_graph, cl_graph, self.w) 46 | 47 | # Check for convergence 48 | cluster_centers_shift = (prev_cluster_centers - cluster_centers) 49 | converged = np.allclose(cluster_centers_shift, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0) 50 | 51 | if converged: 52 | break 53 | 54 | # print('\t', iteration, converged) 55 | 56 | self.cluster_centers_, self.labels_ = cluster_centers, labels 57 | self.As_ = As 58 | 59 | return self 60 | 61 | def _find_farthest_pairs_of_points(self, X, As): 62 | farthest = [None] * self.n_clusters 63 | 64 | n = X.shape[0] 65 | for cluster_id in range(self.n_clusters): 66 | max_distance = 0 67 | 68 | for i in range(n): 69 | for j in range(n): 70 | if j < i: 71 | distance = self._dist(X[i], X[j], As[cluster_id]) 72 | if distance > max_distance: 73 | max_distance = distance 74 | farthest[cluster_id] = (i, j, distance) 75 | 76 | return farthest 77 | 78 | def _initialize_cluster_centers(self, X, neighborhoods): 79 | neighborhood_centers = np.array([X[neighborhood].mean(axis=0) for neighborhood in neighborhoods]) 80 | neighborhood_sizes = np.array([len(neighborhood) for neighborhood in neighborhoods]) 81 | neighborhood_weights = neighborhood_sizes / neighborhood_sizes.sum() 82 | 83 | # print('\t', len(neighborhoods), neighborhood_sizes) 84 | 85 | if len(neighborhoods) > self.n_clusters: 86 | cluster_centers = neighborhood_centers[weighted_farthest_first_traversal(neighborhood_centers, neighborhood_weights, self.n_clusters)] 87 | else: 88 | if len(neighborhoods) > 0: 89 | cluster_centers = neighborhood_centers 90 | else: 91 | cluster_centers = np.empty((0, X.shape[1])) 92 | 93 | if len(neighborhoods) < self.n_clusters: 94 | remaining_cluster_centers = X[np.random.choice(X.shape[0], self.n_clusters - len(neighborhoods), replace=False), :] 95 | cluster_centers = np.concatenate([cluster_centers, remaining_cluster_centers]) 96 | 97 | return cluster_centers 98 | 99 | def _dist(self, x, y, A): 100 | "(x - y)^T A (x - y)" 101 | return scipy.spatial.distance.mahalanobis(x, y, A) ** 2 102 | 103 | def _objective_function(self, X, i, labels, cluster_centers, cluster_id, As, farthest, ml_graph, cl_graph, w): 104 | term_d = self._dist(X[i], cluster_centers[cluster_id], As[cluster_id]) - np.log(max(np.linalg.det(As[cluster_id]), 1e-9)) 105 | 106 | def f_m(i, c_i, j, c_j, As): 107 | return 1 / 2 * self._dist(X[i], X[j], As[c_i]) + 1 / 2 * self._dist(X[i], X[j], As[c_j]) 108 | 109 | def f_c(i, c_i, j, c_j, As, farthest): 110 | return farthest[c_i][2] - self._dist(X[i], X[j], As[c_i]) 111 | 112 | term_m = 0 113 | for j in ml_graph[i]: 114 | if labels[j] >= 0 and labels[j] != cluster_id: 115 | term_m += 2 * w * f_m(i, cluster_id, j, labels[j], As) 116 | 117 | term_c = 0 118 | for j in cl_graph[i]: 119 | if labels[j] == cluster_id: 120 | term_c += 2 * w * f_c(i, cluster_id, j, labels[j], As, farthest) 121 | 122 | return term_d + term_m + term_c 123 | 124 | def _assign_clusters(self, X, y, cluster_centers, As, farthest, ml_graph, cl_graph, w): 125 | labels = np.full(X.shape[0], fill_value=-1) 126 | 127 | index = list(range(X.shape[0])) 128 | np.random.shuffle(index) 129 | for i in index: 130 | labels[i] = np.argmin( 131 | [self._objective_function(X, i, labels, cluster_centers, cluster_id, As, farthest, ml_graph, cl_graph, w) for cluster_id, cluster_center in enumerate(cluster_centers)]) 132 | 133 | # Handle empty clusters 134 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 135 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 136 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 137 | 138 | if len(empty_clusters) > 0: 139 | # print("Empty clusters") 140 | raise EmptyClustersException 141 | 142 | return labels 143 | 144 | def _update_metrics(self, X, labels, cluster_centers, farthest, ml_graph, cl_graph, w): 145 | As = [] 146 | 147 | for cluster_id in range(self.n_clusters): 148 | X_i = X[labels == cluster_id] 149 | n = X_i.shape[0] 150 | 151 | if n == 1: 152 | As.append(np.identity(X_i.shape[1])) 153 | continue 154 | 155 | A_inv = (X_i - cluster_centers[cluster_id]).T @ (X_i - cluster_centers[cluster_id]) 156 | 157 | for i in range(X.shape[0]): 158 | for j in ml_graph[i]: 159 | if labels[i] == cluster_id or labels[j] == cluster_id: 160 | if labels[i] != labels[j]: 161 | A_inv += 1 / 2 * w * ((X[i][:, None] - X[j][:, None]) @ (X[i][:, None] - X[j][:, None]).T) 162 | 163 | for i in range(X.shape[0]): 164 | for j in cl_graph[i]: 165 | if labels[i] == cluster_id or labels[j] == cluster_id: 166 | if labels[i] == labels[j]: 167 | A_inv += w * ( 168 | ((X[farthest[cluster_id][0]][:, None] - X[farthest[cluster_id][1]][:, None]) @ (X[farthest[cluster_id][0]][:, None] - X[farthest[cluster_id][1]][:, None]).T) - ( 169 | (X[i][:, None] - X[j][:, None]) @ (X[i][:, None] - X[j][:, None]).T)) 170 | 171 | # Handle the case when the matrix is not invertible 172 | if not self._is_invertible(A_inv): 173 | # print("Not invertible") 174 | A_inv += 1e-9 * np.trace(A_inv) * np.identity(A_inv.shape[0]) 175 | 176 | A = n * np.linalg.inv(A_inv) 177 | 178 | # Is A positive semidefinite? 179 | if not np.all(np.linalg.eigvals(A) >= 0): 180 | # print("Negative definite") 181 | eigenvalues, eigenvectors = np.linalg.eigh(A) 182 | A = eigenvectors @ np.diag(np.maximum(0, eigenvalues)) @ np.linalg.inv(eigenvectors) 183 | 184 | As.append(A) 185 | 186 | return As 187 | 188 | def _get_cluster_centers(self, X, labels): 189 | return np.array([X[labels == i].mean(axis=0) for i in range(self.n_clusters)]) 190 | 191 | def _is_invertible(self, A): 192 | return A.shape[0] == A.shape[1] and np.linalg.matrix_rank(A) == A.shape[0] 193 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/pckmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from active_semi_clustering.exceptions import EmptyClustersException 4 | from .constraints import preprocess_constraints 5 | 6 | 7 | class PCKMeans: 8 | def __init__(self, n_clusters=3, max_iter=100, w=1): 9 | self.n_clusters = n_clusters 10 | self.max_iter = max_iter 11 | self.w = w 12 | 13 | def fit(self, X, y=None, ml=[], cl=[]): 14 | # Preprocess constraints 15 | ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0]) 16 | 17 | # Initialize centroids 18 | cluster_centers = self._initialize_cluster_centers(X, neighborhoods) 19 | 20 | # Repeat until convergence 21 | for iteration in range(self.max_iter): 22 | # Assign clusters 23 | labels = self._assign_clusters(X, cluster_centers, ml_graph, cl_graph, self.w) 24 | 25 | # Estimate means 26 | prev_cluster_centers = cluster_centers 27 | cluster_centers = self._get_cluster_centers(X, labels) 28 | 29 | # Check for convergence 30 | difference = (prev_cluster_centers - cluster_centers) 31 | converged = np.allclose(difference, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0) 32 | 33 | if converged: break 34 | 35 | self.cluster_centers_, self.labels_ = cluster_centers, labels 36 | 37 | return self 38 | 39 | def _initialize_cluster_centers(self, X, neighborhoods): 40 | neighborhood_centers = np.array([X[neighborhood].mean(axis=0) for neighborhood in neighborhoods]) 41 | neighborhood_sizes = np.array([len(neighborhood) for neighborhood in neighborhoods]) 42 | 43 | if len(neighborhoods) > self.n_clusters: 44 | # Select K largest neighborhoods' centroids 45 | cluster_centers = neighborhood_centers[np.argsort(neighborhood_sizes)[-self.n_clusters:]] 46 | else: 47 | if len(neighborhoods) > 0: 48 | cluster_centers = neighborhood_centers 49 | else: 50 | cluster_centers = np.empty((0, X.shape[1])) 51 | 52 | # FIXME look for a point that is connected by cannot-links to every neighborhood set 53 | 54 | if len(neighborhoods) < self.n_clusters: 55 | remaining_cluster_centers = X[np.random.choice(X.shape[0], self.n_clusters - len(neighborhoods), replace=False), :] 56 | cluster_centers = np.concatenate([cluster_centers, remaining_cluster_centers]) 57 | 58 | return cluster_centers 59 | 60 | def _objective_function(self, X, x_i, centroids, c_i, labels, ml_graph, cl_graph, w): 61 | distance = 1 / 2 * np.sum((X[x_i] - centroids[c_i]) ** 2) 62 | 63 | ml_penalty = 0 64 | for y_i in ml_graph[x_i]: 65 | if labels[y_i] != -1 and labels[y_i] != c_i: 66 | ml_penalty += w 67 | 68 | cl_penalty = 0 69 | for y_i in cl_graph[x_i]: 70 | if labels[y_i] == c_i: 71 | cl_penalty += w 72 | 73 | return distance + ml_penalty + cl_penalty 74 | 75 | def _assign_clusters(self, X, cluster_centers, ml_graph, cl_graph, w): 76 | labels = np.full(X.shape[0], fill_value=-1) 77 | 78 | index = list(range(X.shape[0])) 79 | np.random.shuffle(index) 80 | for x_i in index: 81 | labels[x_i] = np.argmin([self._objective_function(X, x_i, cluster_centers, c_i, labels, ml_graph, cl_graph, w) for c_i in range(self.n_clusters)]) 82 | 83 | # Handle empty clusters 84 | # See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309 85 | n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters) 86 | empty_clusters = np.where(n_samples_in_cluster == 0)[0] 87 | 88 | if len(empty_clusters) > 0: 89 | # print("Empty clusters") 90 | raise EmptyClustersException 91 | 92 | return labels 93 | 94 | def _get_cluster_centers(self, X, labels): 95 | return np.array([X[labels == i].mean(axis=0) for i in range(self.n_clusters)]) 96 | -------------------------------------------------------------------------------- /active_semi_clustering/semi_supervised/pairwise_constraints/rcakmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.cluster import KMeans 4 | from metric_learn import RCA 5 | 6 | from .constraints import preprocess_constraints 7 | 8 | 9 | class RCAKMeans: 10 | """ 11 | Relative Components Analysis (RCA) + KMeans 12 | """ 13 | 14 | def __init__(self, n_clusters=3, max_iter=100): 15 | self.n_clusters = n_clusters 16 | self.max_iter = max_iter 17 | 18 | def fit(self, X, y=None, ml=[], cl=[]): 19 | X_transformed = X 20 | 21 | if ml: 22 | chunks = np.full(X.shape[0], -1) 23 | ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0]) 24 | for i, neighborhood in enumerate(neighborhoods): 25 | chunks[neighborhood] = i 26 | 27 | # print(chunks) 28 | 29 | rca = RCA() 30 | rca.fit(X, chunks=chunks) 31 | X_transformed = rca.transform(X) 32 | 33 | # print(rca.metric()) 34 | 35 | kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter) 36 | kmeans.fit(X_transformed) 37 | 38 | self.labels_ = kmeans.labels_ 39 | 40 | return self 41 | -------------------------------------------------------------------------------- /examples/Active-Semi-Supervised-Clustering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Active Semi-Supervised Clustering" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from sklearn import datasets, metrics\n", 17 | "from active_semi_clustering.semi_supervised.pairwise_constraints import PCKMeans\n", 18 | "from active_semi_clustering.active.pairwise_constraints import ExampleOracle, ExploreConsolidate, MinMax" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "X, y = datasets.load_iris(return_X_y=True)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "First, obtain some pairwise constraints from the oracle." 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# TODO implement your own oracle that will, for example, query a domain expert via GUI or CLI\n", 44 | "oracle = ExampleOracle(y, max_queries_cnt=10)\n", 45 | "\n", 46 | "active_learner = MinMax(n_clusters=3)\n", 47 | "active_learner.fit(X, oracle=oracle)\n", 48 | "pairwise_constraints = active_learner.pairwise_constraints_" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Then, use the constraints to do the clustering." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "clusterer = PCKMeans(n_clusters=3)\n", 65 | "clusterer.fit(X, ml=pairwise_constraints[0], cl=pairwise_constraints[1])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "Evaluate the clustering using Adjusted Rand Score." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "metrics.adjusted_rand_score(y, clusterer.labels_)" 82 | ] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "Python 3", 88 | "language": "python", 89 | "name": "python3" 90 | }, 91 | "language_info": { 92 | "codemirror_mode": { 93 | "name": "ipython", 94 | "version": 3 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython3", 101 | "version": "3.6.3" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="active-semi-supervised-clustering", 8 | version="0.0.1", 9 | author="Jakub Svehla", 10 | author_email="jakub.svehla@datamole.cz", 11 | description="Active semi-supervised clustering algorithms for scikit-learn", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/datamole-ai/active-semi-supervised-clustering", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | install_requires=[ 22 | 'numpy', 23 | 'scipy', 24 | 'scikit-learn', 25 | 'metric-learn>=0.4', 26 | ] 27 | ) 28 | --------------------------------------------------------------------------------