├── .gitattributes ├── utils ├── __init__.py └── scores.py ├── src ├── __init__.py ├── utils │ ├── __init__.py │ ├── tools.py │ ├── validations.py │ ├── norms.py │ └── metrics.py ├── dataset.py ├── _kmeanspp.py ├── _discern.py ├── minibatchkmeans.py └── kmeans.py ├── README.md ├── .gitignore └── examples ├── MiniBatch++_Yale.ipynb ├── MiniBatch_BBC.ipynb └── MiniBatch_BBC_CUDA.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | .ipynb linguist-documentation -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Utils 6 | """ 7 | from .scores import purity_score 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | """ 6 | 7 | from .kmeans import KMeans 8 | from .minibatchkmeans import MiniBatchKMeans 9 | from .dataset import KMeansDataset 10 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Common util functions 6 | """ 7 | 8 | from .metrics import distance_matrix, self_distance_matrix, similarity_matrix, self_similarity_matrix 9 | from .norms import row_norm, squared_norm 10 | from .tools import torch_unravel_index 11 | -------------------------------------------------------------------------------- /src/utils/tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Misc tools 6 | """ 7 | import torch 8 | 9 | 10 | def torch_unravel_index(index, shape): 11 | 
""" 12 | Unravel index for torch tensors 13 | By ModarTensai -- PyTorch Forums 14 | https://discuss.pytorch.org/u/ModarTensai 15 | 16 | Parameters 17 | --------- 18 | index : int 19 | shape : tuple 20 | 21 | Returns 22 | ------- 23 | index : tuple 24 | """ 25 | out = [] 26 | for dim in reversed(shape): 27 | out.append(index % dim) 28 | index = index // dim 29 | return tuple(reversed(out)) 30 | -------------------------------------------------------------------------------- /src/utils/validations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Validation utils 6 | """ 7 | import torch 8 | 9 | 10 | def distance_validation(x): 11 | """ 12 | Clamps the distance matrix to prevent invalid values. 13 | 14 | Parameters 15 | ---------- 16 | x : torch.Tensor 17 | 18 | Returns 19 | ------- 20 | x_out : torch.Tensor 21 | """ 22 | return torch.clamp_min(x, 0.0) 23 | 24 | 25 | def similarity_validation(x): 26 | """ 27 | Clamps the similarity matrix to prevent invalid values. 28 | 29 | Parameters 30 | ---------- 31 | x : torch.Tensor 32 | 33 | Returns 34 | ------- 35 | x_out : torch.Tensor 36 | """ 37 | return torch.clamp(x, 0.0, 1.0) 38 | -------------------------------------------------------------------------------- /src/utils/norms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Norm utils 6 | """ 7 | 8 | from torch.nn import functional as F 9 | 10 | 11 | def squared_norm(x): 12 | """ 13 | Computes and returns the squared norm of the input 2d tensor on dimension 1. 14 | Useful for computing euclidean distance matrix. 
15 | 16 | Parameters 17 | ---------- 18 | x : torch.Tensor of shape (n, m) 19 | 20 | Returns 21 | ------- 22 | x_squared_norm : torch.Tensor of shape (n, 1) 23 | """ 24 | return (x ** 2).sum(1).view(-1, 1) 25 | 26 | 27 | def row_norm(x): 28 | """ 29 | Computes and returns the row-normalized version of the input 2d tensor. 30 | Useful for computing cosine similarity matrix. 31 | 32 | Parameters 33 | ---------- 34 | x : torch.Tensor of shape (n, m) 35 | 36 | Returns 37 | ------- 38 | x_normalized : torch.Tensor of shape (n, m) 39 | """ 40 | return F.normalize(x, p=2, dim=1) 41 | -------------------------------------------------------------------------------- /utils/scores.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def purity_score(y_true, y_pred): 5 | """ 6 | Computes the purity score (clustering accuracy) 7 | 8 | Parameters 9 | ---------- 10 | y_true : torch.Tensor[int] or torch.Tensor[long] and of shape (n_samples, ) 11 | Ground truth labels 12 | 13 | y_pred : torch.Tensor[int] or torch.Tensor[long] and of shape (n_samples, ) 14 | Predicted labels 15 | 16 | Returns 17 | ------- 18 | accuracy : float 19 | """ 20 | n = y_true.size(0) 21 | unique_classes = y_true.unique() 22 | unique_clusters = y_pred.unique() 23 | num_classes = len(unique_classes) 24 | num_clusters = len(unique_clusters) 25 | class_to_idx = {int(unique_classes[i]): i for i in range(num_classes)} 26 | cluster_to_idx = {int(unique_clusters[i]): i for i in range(num_clusters)} 27 | 28 | scores = torch.zeros((num_classes, num_clusters), dtype=torch.int16, device=y_true.device) 29 | for i in range(n): 30 | scores[class_to_idx[int(y_true[i])], cluster_to_idx[int(y_pred[i])]] += 1 31 | return scores.max(0).values.sum().item() / n 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch-based K-Means 2 | A 
torch-based implementation of K-Means, MiniBatch K-Means, K-Means++ and more with customizable distance metrics, 3 | and similarity-based clustering. 4 | 5 | ## Notes 6 | Please note that this repository is still in the WIP phase, but feel free to jump in. 7 | 8 | The goal is to reach the fastest and cleanest implementation of K-Means, K-Means++ and Mini-Batch K-Means using 9 | PyTorch for CUDA-enabled clustering. 10 | 11 | 12 | Here's the progress so far: 13 | 14 | :white_check_mark: K-Means 15 | 16 | :white_check_mark: Similarity-based K-Means (Spherical K-Means) 17 | 18 | :white_check_mark: Custom metrics for K-Means 19 | 20 | :white_check_mark: K-Means++ initialization 21 | 22 | :white_check_mark: DISCERN initialization 23 | 24 | :white_check_mark: Purity score 25 | 26 | :white_check_mark: MiniBatch K-Means 27 | 28 | :black_square_button: (Testing) MiniBatch K-Means++ initialization 29 | 30 | :black_square_button: (In progress) MiniBatch K-Means optimized by torch.optim 31 | 32 |    Successful implementation, much faster than the previous MiniBatch K-Means implementation, but not as accurate. -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | K-Means Dataset 6 | """ 7 | import random 8 | from torch.utils.data import Dataset 9 | from .utils import squared_norm, row_norm 10 | 11 | 12 | class KMeansDataset(Dataset): 13 | """ 14 | K-Means Compatible Dataset 15 | """ 16 | def __init__(self, x, metric='default', similarity_based=False): 17 | self.data = x 18 | self.data_norm = None 19 | if type(metric) is str and metric == 'default': 20 | self.data_norm = row_norm(x) if similarity_based else squared_norm(x) 21 | 22 | def random_sample(self, n): 23 | """ 24 | Returns n random samples from the dataset. 
25 | """ 26 | idx = random.sample(range(self.data.size(0)), n) 27 | return self.__getitem__(idx) 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | @property 33 | def dim(self): 34 | return self.data.shape[1] 35 | 36 | def __getitem__(self, idx): 37 | if self.data_norm is not None: 38 | return self.data[idx, :], self.data_norm[idx, :] 39 | return self.data[idx, :], [None for _ in range(len(idx))] 40 | -------------------------------------------------------------------------------- /src/_kmeanspp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | K-Means++ initializer 6 | 7 | Arthur, David, and Sergei Vassilvitskii. k-means++: The advantages of careful seeding. Stanford, 2006. 8 | Manuscript available at: http://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf 9 | """ 10 | import time 11 | import numpy as np 12 | import torch 13 | from .utils import distance_matrix, squared_norm 14 | 15 | 16 | def k_means_pp(x, n_clusters, x_norm=None): 17 | """ 18 | K-Means++ initialization 19 | 20 | Based on Scikit-Learn's implementation 21 | 22 | Parameters 23 | ---------- 24 | x : torch.Tensor of shape (n_training_samples, n_features) 25 | n_clusters : int 26 | x_norm : torch.Tensor of shape (n_training_samples, ) or NoneType 27 | 28 | Returns 29 | ------- 30 | centroids : torch.Tensor of shape (n_clusters, n_features) 31 | """ 32 | if x_norm is None: 33 | x_norm = squared_norm(x) 34 | n_samples, n_features = x.shape 35 | 36 | centroids = torch.zeros((n_clusters, n_features), dtype=x.dtype, device=x.device) 37 | 38 | n_local_trials = 2 + int(np.log(n_clusters)) 39 | 40 | initial_centroid_idx = torch.randint(low=0, high=n_samples, size=(1,), device=x.device)[0] 41 | centroids[0, :] = x[initial_centroid_idx, :] 42 | 43 | dist_mat = distance_matrix(x=centroids[0, :].unsqueeze(0), y=x, 44 | x_norm=x_norm[initial_centroid_idx, :].unsqueeze(0), y_norm=x_norm) 45 | 
current_potential = dist_mat.sum(1) 46 | 47 | for c in range(1, n_clusters): 48 | rand_vals = torch.rand(n_local_trials, device=x.device) * current_potential 49 | candidate_ids = torch.searchsorted(torch.cumsum(dist_mat.squeeze(0), dim=0), rand_vals) 50 | torch.clamp_max(candidate_ids, dist_mat.size(1) - 1, out=candidate_ids) 51 | 52 | distance_to_candidates = distance_matrix(x=x[candidate_ids, :], y=x, 53 | x_norm=x_norm[candidate_ids, :], y_norm=x_norm) 54 | 55 | distance_to_candidates = torch.where(dist_mat < distance_to_candidates, dist_mat, distance_to_candidates) 56 | candidates_potential = distance_to_candidates.sum(1) 57 | 58 | best_candidate = torch.argmin(candidates_potential) 59 | current_potential = candidates_potential[best_candidate] 60 | dist_mat = distance_to_candidates[best_candidate].unsqueeze(0) 61 | best_candidate = candidate_ids[best_candidate] 62 | 63 | centroids[c, :] = x[best_candidate, :] 64 | 65 | return centroids 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # Idea 133 | .idea 134 | 135 | # CSVs 136 | *.csv 137 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | Metric utils 6 | """ 7 | 8 | import torch 9 | from .norms import squared_norm, row_norm 10 | from .validations import distance_validation, similarity_validation 11 | 12 | 13 | def distance_matrix(x, y, x_norm=None, y_norm=None): 14 | """ 15 | Returns the pairwise distance matrix between the two input 2d tensors. 16 | 17 | Parameters 18 | ---------- 19 | x : torch.Tensor of shape (n, m) 20 | y : torch.Tensor of shape (p, m) 21 | x_norm : torch.Tensor of shape (n, ) or NoneType 22 | y_norm : torch.Tensor of shape (p, ) or NoneType 23 | 24 | Returns 25 | ------- 26 | distance_matrix : torch.Tensor of shape (n, p) 27 | """ 28 | x_norm = squared_norm(x) if x_norm is None else x_norm 29 | y_norm = squared_norm(y).T if y_norm is None else y_norm.T 30 | mat = x_norm + y_norm - 2.0 * torch.mm(x, y.T) 31 | return distance_validation(mat) 32 | 33 | 34 | def self_distance_matrix(x): 35 | """ 36 | Returns the self distance matrix of the input 2d tensor. 
37 | 38 | Parameters 39 | ---------- 40 | x : torch.Tensor of shape (n, m) 41 | 42 | Returns 43 | ------- 44 | distance_matrix : torch.Tensor of shape (n, n) 45 | """ 46 | return distance_validation(((x.unsqueeze(0) - x.unsqueeze(1)) ** 2).sum(2)) 47 | 48 | 49 | def similarity_matrix(x, y, pre_normalized=False): 50 | """ 51 | Returns the pairwise similarity matrix between the two input 2d tensors. 52 | 53 | Parameters 54 | ---------- 55 | x : torch.Tensor of shape (n, m) 56 | y : torch.Tensor of shape (p, m) 57 | pre_normalized : bool, default=False 58 | Whether the inputs are already row-normalized 59 | 60 | Returns 61 | ------- 62 | similarity_matrix : torch.Tensor of shape (n, p) 63 | """ 64 | if pre_normalized: 65 | return similarity_validation((x.matmul(y.T))) 66 | return similarity_validation((row_norm(x).matmul(row_norm(y).T))) 67 | 68 | 69 | def self_similarity_matrix(x, pre_normalized=False): 70 | """ 71 | Returns the self similarity matrix of the input 2d tensor. 72 | 73 | Parameters 74 | ---------- 75 | x : torch.Tensor of shape (n, m) 76 | pre_normalized : bool, default=False 77 | Whether the input is already row-normalized 78 | 79 | Returns 80 | ------- 81 | similarity_matrix : torch.Tensor of shape (n, n) 82 | """ 83 | if pre_normalized: 84 | return similarity_validation((1 + x.matmul(x.T)) / 2) 85 | x_normalized = row_norm(x) 86 | return similarity_validation((1 + x_normalized.matmul(x_normalized.T)) / 2) 87 | -------------------------------------------------------------------------------- /src/_discern.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | DISCERN initializer 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from .utils import self_similarity_matrix, row_norm, torch_unravel_index 11 | 12 | 13 | def discern(x, n_clusters=None, max_n_clusters=None, x_norm=None): 14 | """ 15 | DISCERN initialization 16 | 17 | Parameters 18 | ---------- 19 | 
x : torch.Tensor of shape (n_training_samples, n_features) 20 | n_clusters : int or NoneType 21 | Estimates the number of clusters if set to None 22 | 23 | max_n_clusters : int or NoneType, default=None 24 | Defaults to n_training_samples / 2 if None 25 | 26 | x_norm : torch.Tensor of shape (n_training_samples, n_features) or NoneType 27 | 28 | Returns 29 | ------- 30 | centroids : torch.Tensor of shape (n_clusters, n_features) 31 | """ 32 | max_n_clusters = max_n_clusters if max_n_clusters is not None else int(x.size(0) / 2) + 1 33 | 34 | if x_norm is None: 35 | x_norm = row_norm(x) 36 | similarity_matrix = self_similarity_matrix(x_norm, True) 37 | 38 | centroid_idx_0, centroid_idx_1 = torch_unravel_index(int(torch.argmin(similarity_matrix)), similarity_matrix.shape) 39 | centroid_idx = [centroid_idx_0, centroid_idx_1] 40 | 41 | remaining = [y for y in range(0, len(similarity_matrix)) if y not in centroid_idx] 42 | 43 | similarity_submatrix = similarity_matrix[centroid_idx, :][:, remaining] 44 | 45 | ctr = 2 46 | max_n_clusters = len(similarity_matrix) if max_n_clusters is None else max_n_clusters 47 | find_n_clusters = n_clusters is None or n_clusters < 2 48 | 49 | membership_values = None if not find_n_clusters else np.zeros(max_n_clusters + 1, dtype=float) 50 | 51 | while len(remaining) > 1 and ctr <= max_n_clusters: 52 | if n_clusters is not None and 1 < n_clusters <= len(centroid_idx): 53 | break 54 | 55 | min_vector, max_vector = torch.min(similarity_submatrix, dim=0).values, \ 56 | torch.max(similarity_submatrix, dim=0).values 57 | diff_vector = max_vector - min_vector 58 | membership_vector = torch.square(max_vector) * min_vector * diff_vector 59 | min_idx = int(torch.argmin(membership_vector).item()) 60 | membership_vector, min_idx, min_value = membership_vector, min_idx, float(membership_vector[min_idx].data) 61 | 62 | new_centroid_idx = remaining[min_idx] 63 | if find_n_clusters: 64 | membership_values[ctr] = min_value 65 | 
centroid_idx.append(new_centroid_idx) 66 | remaining.remove(new_centroid_idx) 67 | similarity_submatrix = similarity_matrix[centroid_idx, :][:, remaining] 68 | ctr += 1 69 | 70 | if find_n_clusters: 71 | membership_values = membership_values[:ctr] 72 | # TODO: torch implementation 73 | rx = range(0, len(membership_values)) 74 | dy = np.gradient(membership_values, rx) 75 | d2y = np.gradient(dy, rx) 76 | kappa = (d2y / ((1 + (dy ** 2)) ** (3 / 2))) 77 | predicted_n_clusters = int(np.argmin(kappa)) 78 | n_clusters = max(predicted_n_clusters, 2) 79 | 80 | centroid_idx = centroid_idx[:n_clusters] 81 | 82 | return x[centroid_idx, :] 83 | -------------------------------------------------------------------------------- /examples/MiniBatch++_Yale.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

1. Import dependencies

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "from torch.utils.data import DataLoader\n", 20 | "import time\n", 21 | "\n", 22 | "from utils.scores import purity_score as purity\n", 23 | "\n", 24 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n", 25 | "\n", 26 | "from sklearn.feature_extraction.text import TfidfVectorizer" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "

2. Import data

" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "#X, Y = load_yale()\n", 43 | "\n", 44 | "X = torch.from_numpy(X)\n", 45 | "Y = torch.from_numpy(Y)\n", 46 | "\n", 47 | "if torch.cuda.is_available():\n", 48 | " X = X.cuda()\n", 49 | " Y = Y.cuda()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "similarity_based = True\n", 59 | "batch_size = 16\n", 60 | "n_clusters = 15" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "dataset = KMeansDataset(X, similarity_based=similarity_based)\n", 70 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "

3. Mini-Batch K-Means

" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": { 84 | "scrolled": true 85 | }, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "Clustering finished in 2.15 seconds.\n", 92 | "[Supervised Performance] Test Accuracy: 73.33 %\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "start = time.time()\n", 98 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n", 99 | "km.fit(dataloader)\n", 100 | "labels = km.transform_tensor(X)\n", 101 | "km_time = time.time() - start\n", 102 | "acc = purity(Y, labels) * 100\n", 103 | "\n", 104 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 105 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "

4. Mini-Batch K-Means++

" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": { 119 | "scrolled": true 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Clustering finished in 2.14 seconds.\n", 127 | "[Supervised Performance] Test Accuracy: 86.67 %\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "start = time.time()\n", 133 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='k-means++', similarity_based=similarity_based)\n", 134 | "km.fit(dataloader)\n", 135 | "labels = km.transform_tensor(X)\n", 136 | "km_time = time.time() - start\n", 137 | "acc = purity(Y, labels) * 100\n", 138 | "\n", 139 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 140 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.8.5" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /examples/MiniBatch_BBC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

1. Import dependencies

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "from torch.utils.data import DataLoader\n", 20 | "import time\n", 21 | "\n", 22 | "from sklearn.preprocessing import LabelEncoder\n", 23 | "from sklearn.metrics import silhouette_score\n", 24 | "from utils.scores import purity_score as purity\n", 25 | "\n", 26 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n", 27 | "\n", 28 | "from sklearn.feature_extraction.text import TfidfVectorizer" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "

2. Import data

" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "(2127, 29422)\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "df = pd.read_csv('bbc.csv')\n", 53 | "\n", 54 | "vectorizer = TfidfVectorizer()\n", 55 | "ds = vectorizer.fit_transform(df['content']).todense()\n", 56 | "classes = df['label'].to_numpy()\n", 57 | "\n", 58 | "idx = np.random.permutation(len(classes))\n", 59 | "ds = ds[idx, :]\n", 60 | "labels = LabelEncoder().fit_transform(np.asarray(classes))[idx]\n", 61 | "n_train = int(len(labels) * 0.7)\n", 62 | "\n", 63 | "print(ds.shape)\n", 64 | "\n", 65 | "X_train = torch.from_numpy(ds[:n_train, :])\n", 66 | "y_train = torch.from_numpy(labels[:n_train])\n", 67 | "X_test = torch.from_numpy(ds[n_train:, :])\n", 68 | "y_test = torch.from_numpy(labels[n_train:])\n", 69 | "\n", 70 | "if torch.cuda.is_available():\n", 71 | " X_train = X_train.cuda()\n", 72 | " y_train = y_train.cuda()\n", 73 | " X_test = X_test.cuda()\n", 74 | " y_test = y_test.cuda()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "similarity_based = False\n", 84 | "batch_size = 64\n", 85 | "n_clusters = 5" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "dataset = KMeansDataset(X_train, similarity_based=similarity_based)\n", 95 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "

3. K-Means

" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": { 109 | "scrolled": true 110 | }, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "Clustering finished in 4.41 seconds.\n", 117 | "[Supervised Performance] Test Accuracy: 65.73 %\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "start = time.time()\n", 123 | "km = KMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n", 124 | "km.fit(X_train)\n", 125 | "labels = km.transform(X_test)\n", 126 | "km_time = time.time() - start\n", 127 | "acc = purity(y_test, labels) * 100\n", 128 | "\n", 129 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 130 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "

4. Mini-Batch K-Means

" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": { 144 | "scrolled": true 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Clustering finished in 45.6 seconds.\n", 152 | "[Supervised Performance] Test Accuracy: 84.66 %\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "start = time.time()\n", 158 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n", 159 | "km.fit(dataloader)\n", 160 | "labels = km.transform_tensor(X_test)\n", 161 | "km_time = time.time() - start\n", 162 | "acc = purity(y_test, labels) * 100\n", 163 | "\n", 164 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 165 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 166 | ] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.8.5" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /examples/MiniBatch_BBC_CUDA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

1. Import dependencies

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "from torch.utils.data import DataLoader\n", 20 | "import time\n", 21 | "\n", 22 | "from sklearn.preprocessing import LabelEncoder\n", 23 | "from sklearn.metrics import silhouette_score\n", 24 | "from utils.scores import purity_score as purity\n", 25 | "\n", 26 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n", 27 | "\n", 28 | "from sklearn.feature_extraction.text import TfidfVectorizer" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "

2. Import data

" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "(2127, 29422)\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "df = pd.read_csv('bbc.csv')\n", 53 | "\n", 54 | "vectorizer = TfidfVectorizer()\n", 55 | "ds = vectorizer.fit_transform(df['content']).todense()\n", 56 | "classes = df['label'].to_numpy()\n", 57 | "\n", 58 | "idx = np.random.permutation(len(classes))\n", 59 | "ds = ds[idx, :]\n", 60 | "labels = LabelEncoder().fit_transform(np.asarray(classes))[idx]\n", 61 | "n_train = int(len(labels) * 0.7)\n", 62 | "\n", 63 | "print(ds.shape)\n", 64 | "\n", 65 | "X_train = torch.from_numpy(ds[:n_train, :])\n", 66 | "y_train = torch.from_numpy(labels[:n_train])\n", 67 | "X_test = torch.from_numpy(ds[n_train:, :])\n", 68 | "y_test = torch.from_numpy(labels[n_train:])\n", 69 | "\n", 70 | "if torch.cuda.is_available():\n", 71 | " X_train = X_train.cuda()\n", 72 | " y_train = y_train.cuda()\n", 73 | " X_test = X_test.cuda()\n", 74 | " y_test = y_test.cuda()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "similarity_based = False\n", 84 | "batch_size = 64\n", 85 | "n_clusters = 5" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "dataset = KMeansDataset(X_train, similarity_based=similarity_based)\n", 95 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "

3. K-Means

" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": { 109 | "scrolled": true 110 | }, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "Clustering finished in 0.449 seconds.\n", 117 | "[Supervised Performance] Test Accuracy: 56.34 %\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "start = time.time()\n", 123 | "km = KMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n", 124 | "km.fit(X_train)\n", 125 | "labels = km.transform(X_test)\n", 126 | "km_time = time.time() - start\n", 127 | "acc = purity(y_test, labels) * 100\n", 128 | "\n", 129 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 130 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "

4. Mini-Batch K-Means

" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": { 144 | "scrolled": true 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Clustering finished in 11.4 seconds.\n", 152 | "[Supervised Performance] Test Accuracy: 73.40 %\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "start = time.time()\n", 158 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n", 159 | "km.fit(dataloader)\n", 160 | "labels = km.transform_tensor(X_test)\n", 161 | "km_time = time.time() - start\n", 162 | "acc = purity(y_test, labels) * 100\n", 163 | "\n", 164 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n", 165 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))" 166 | ] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.8.5" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /src/minibatchkmeans.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch-based K-Means 3 | by Ali Hassani 4 | 5 | MiniBatch K-Means 6 | 7 | Sculley, David. "Web-scale k-means clustering." Proceedings of the 19th international conference on 8 | World wide web. 2010. 
class MiniBatchKMeans(_BaseKMeans):
    """
    Mini Batch K-Means

    Parameters
    ----------
    n_clusters : int
        The number of clusters or `K`. May be omitted when `init` is a tensor, in which case it is taken
        from the number of rows of that tensor.

    init : 'random', 'k-means++' or torch.Tensor of shape (n_clusters, n_features)
        Tensor of the initial centroid coordinates, or the name of a pre-defined method.
        NOTE: every non-tensor value (including 'random') is currently replaced by 'k-means++' in the
        constructor.

    n_init : int, default=10
        Ignored (Number of initializations).
        NOTE: not yet supported.

    max_iter : int, default=200
        Maximum K-Means iterations (full passes over the dataloader).

    metric : 'default' or callable, default='default'
        Distance metric when similarity_based=False and similarity metric otherwise. Default is 'default'
        which uses L2 distance and cosine similarity as the distance and similarity metrics respectively.
        The callable metrics should take in two tensors of shapes (n, d) and (m, d) and return a tensor of
        shape (n, m).

    similarity_based : bool, default=False
        Whether the metric is a similarity metric or not.

    eps : float, default=1e-6
        Threshold for early stopping: training stops once the inertia changes by less than `eps` between
        two consecutive passes.

    Attributes
    ----------
    labels_ : torch.Tensor of shape (n_training_samples,) or NoneType
        Training cluster assignments. Only populated by `fit_transform`.

    cluster_centers_ : torch.Tensor of shape (n_clusters, n_features)
        Final centroid coordinates

    inertia_ : scalar
        Inertia accumulated over the last pass: sum of distances returned by the metric when not
        similarity_based and sum of similarities when similarity_based

    n_iter_ : int
        The number of training iterations (passes over the dataloader) performed

    cluster_counts : numpy.ndarray of shape (n_clusters,)
        Number of samples assigned to each cluster so far, accumulated over all passes of the last `fit`.
    """
    def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
                 similarity_based=False, eps=1e-6):
        user_centers = init if isinstance(init, torch.Tensor) else None
        # NOTE(review): any non-tensor `init` (even 'random') is coerced to 'k-means++', which makes
        # `_initialize_random` unreachable through the constructor. Kept as-is because enabling it
        # depends on `KMeansDataset.random_sample`, which is not visible from this module -- confirm.
        init = init if user_centers is not None else 'k-means++'
        super(MiniBatchKMeans, self).__init__(n_clusters=n_clusters, init=init, n_init=n_init,
                                              max_iter=max_iter, metric=metric,
                                              similarity_based=similarity_based, eps=eps)
        # Private copy of the user-provided centroids: training updates the centroids in place, and the
        # base class maps a tensor `init` to the 'k-means++' method, which would otherwise overwrite it.
        self._init_centers = user_centers.clone() if user_centers is not None else None

    def _initialize(self, dataloader):
        """
        Initializes the centroid coordinates and resets the training state.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If the initialization method is not one of the supported ones.
        """
        self.labels_ = None
        self.inertia_ = 0
        self.n_iter_ = 0
        if self._init_centers is not None:
            # Honour explicitly provided centroids; clone so that refitting starts from the same point.
            self.cluster_centers_ = self._init_centers.clone()
            self.n_clusters = self.cluster_centers_.size(0)
        elif self.init_method == 'k-means++':
            self._initialize_kpp(dataloader)
        elif self.init_method == 'random':
            self._initialize_random(dataloader)
        else:
            raise NotImplementedError("Initialization `{}` not supported.".format(self.init_method))
        # Always refresh the cached centroid norms so that they never belong to a previous fit.
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_random(self, dataloader):
        """
        Initializes the centroid coordinates by randomly selecting from the training samples.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if not isinstance(self.n_clusters, int):
            raise NotImplementedError("Randomized K-Means expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        # TODO: Better implementation of random initialization
        # The norms returned by the dataset are discarded and recomputed from the sampled centroids.
        self.cluster_centers_, _ = dataloader.dataset.random_sample(self.n_clusters)
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_kpp(self, dataloader):
        """
        Initializes the centroid coordinates using K-Means++ on the first mini-batch only.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        # TODO: Mini-batch K-Means++
        if not isinstance(self.n_clusters, int):
            raise NotImplementedError("K-Means++ expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        x, x_norm = next(iter(dataloader))
        self.cluster_centers_ = k_means_pp(x, n_clusters=self.n_clusters,
                                           x_norm=x_norm if not self.similarity_based else None)
        return self

    def fit(self, dataloader):
        """
        Initializes and fits the centroids using the samples given w.r.t the metric.

        Training stops after `max_iter` passes over the dataloader, or earlier once the inertia of two
        consecutive passes differs by less than `eps`.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self
        """
        self._initialize(dataloader)
        self.inertia_ = None
        # Per-cluster sample counts are accumulated over all passes, so the step size keeps decaying.
        self.cluster_counts = np.zeros(self.n_clusters, dtype=int)
        for itr in range(self.max_iter):
            inertia = self._fit_iter(dataloader)
            self.n_iter_ = itr + 1
            converged = self.inertia_ is not None and abs(self.inertia_ - inertia) < self.eps
            # The inertia has to be stored after every pass; otherwise the comparison above never has a
            # previous value to work with and early stopping can never trigger.
            self.inertia_ = inertia
            if converged:
                break
        return self

    def _fit_iter(self, dataloader):
        """
        Performs one iteration (one pass over the dataloader) of the mini-batch K-Means and updates the
        centroids after every mini-batch.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        inertia : scalar
            Sum of the per-batch inertia values over the whole pass (0 for an empty dataloader).
        """
        # TODO: Cleaner and faster implementation
        total_inertia = 0
        for x, x_norm in dataloader:
            labels, inertia = self._assign(x, x_norm)
            total_inertia += inertia
            for c in range(self.n_clusters):
                idx = torch.where(labels == c)[0]
                n_new = len(idx)
                if n_new == 0:
                    continue
                self.cluster_counts[c] += n_new
                count = float(self.cluster_counts[c])
                # Running mean over every sample ever assigned to the cluster: the previous centroid
                # keeps weight (count - n_new) / count and each new sample gets weight 1 / count. This is
                # the batched equivalent of the per-sample update with learning rate 1 / count.
                keep = (count - n_new) / count
                batch_sum = torch.sum(torch.index_select(x, 0, idx), dim=0)
                self.cluster_centers_[c, :] = keep * self.cluster_centers_[c, :] + batch_sum / count
            # Centroids moved, so their cached norms must be refreshed before the next assignment.
            self.center_norm = self._normalize(self.cluster_centers_)
        return total_inertia

    def transform(self, dataloader):
        """
        Assigns the samples in the dataloader given to the clusters w.r.t the centroid coordinates
        and metric.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
            Assignments in the order in which the dataloader yields the samples.
        """
        label_list = [self._assign(x, x_norm)[0] for x, x_norm in dataloader]
        return torch.cat(label_list)

    def transform_tensor(self, x):
        """
        Assigns the samples given to the clusters w.r.t the centroid coordinates and metric.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        labels, _ = self._assign(x)
        return labels

    def fit_transform(self, dataloader):
        """
        Fits the centroids using the samples given w.r.t the metric, returns the final assignments.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
            Assignments w.r.t the final centroids, in the order in which the dataloader yields the
            samples during the labelling pass (not the dataset order when the dataloader shuffles).
        """
        self.fit(dataloader)
        # `fit` does not keep per-sample assignments, so label the data once with the final centroids.
        self.labels_ = self.transform(dataloader)
        return self.labels_
class _BaseKMeans:
    """
    Base K-Means : DO NOT USE DIRECTLY

    Holds the hyper-parameters and the sample-to-centroid assignment logic shared by the concrete
    K-Means variants.

    Parameters
    ----------
    n_clusters : int or NoneType
        The number of clusters or `K`. Set to None ONLY when init = 'discern'.

    init : 'random', 'k-means++', 'discern', callable or torch.Tensor of shape (n_clusters, n_features)
        Tensor of the initial centroid coordinates, one of the pre-defined methods {'random', 'k-means++',
        'discern'} or callable taking the training data as input and returning the centroid coordinates.
        A tensor is stored in `cluster_centers_` and the initialization method falls back to 'k-means++'.

    n_init : int, default=10
        Ignored (Number of initializations).
        NOTE: not yet supported.

    max_iter : int, default=200
        Maximum K-Means iterations.

    metric : 'default' or callable, default='default'
        Distance metric when similarity_based=False and similarity metric otherwise. Default is 'default'
        which uses L2 distance and cosine similarity as the distance and similarity metrics respectively.
        Any non-callable value is treated as 'default'.
        WARNING: This metric does not apply to the pre-defined initialization methods (K-Means++ and DISCERN).
        The callable metrics should take in two tensors of shapes (n, d) and (m, d) and return a tensor of
        shape (n, m).

    similarity_based : bool, default=False
        Whether the metric is a similarity metric or not.

    eps : float, default=1e-6
        Threshold for early stopping.

    Attributes
    ----------
    labels_ : torch.Tensor of shape (n_training_samples,)
        Training cluster assignments

    cluster_centers_ : torch.Tensor of shape (n_clusters, n_features)
        Final centroid coordinates

    inertia_ : float
        Sum of squared errors when not similarity_based and sum of similarities when similarity_based

    n_iter_ : int
        The number of training iterations
    """
    def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
                 similarity_based=False, eps=1e-6):
        self.n_clusters = n_clusters
        # Strings and callables name the initialization method; anything else (e.g. a tensor of
        # centroids) falls back to 'k-means++'.
        self.init_method = init if isinstance(init, str) or callable(init) else 'k-means++'
        self.cluster_centers_ = init if isinstance(init, torch.Tensor) else None
        self.max_iter = max_iter
        self.metric = metric if callable(metric) else 'default'
        self.similarity_based = similarity_based
        self.eps = eps

        self.center_norm = None
        self.labels_ = None
        self.inertia_ = 0
        self.n_iter_ = 0

    def _normalize(self, x):
        """
        Computes the norm representation cached alongside samples and centroids: `row_norm(x)` when
        similarity_based and `squared_norm(x)` otherwise.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        x_norm : output of `row_norm` or `squared_norm`
        """
        return row_norm(x) if self.similarity_based else squared_norm(x)

    def _assign(self, x, x_norm=None):
        """
        Takes a set of samples and assigns them to the clusters w.r.t the centroid coordinates and metric.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        inertia : torch.Tensor (scalar)
            Sum of the distances (or similarities) between every sample and its assigned centroid.
        """
        if self.similarity_based:
            return self._similarity_based_assignment(x, x_norm)
        return self._distance_based_assignment(x, x_norm)

    def _distance_based_assignment(self, x, x_norm=None):
        """
        Takes a set of samples and assigns each of them to the centroid with the smallest distance.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or NoneType
            Only used by the default metric.

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        inertia : torch.Tensor (scalar)
            Sum of the distances between every sample and its closest centroid.
        """
        if callable(self.metric):
            dist = self.metric(x, self.cluster_centers_)
        else:
            dist = distance_matrix(x, self.cluster_centers_, x_norm=x_norm, y_norm=self.center_norm)
        # A single reduction yields both the winning centroid and its distance.
        closest = torch.min(dist, dim=1)
        return closest.indices, torch.sum(closest.values)

    def _similarity_based_assignment(self, x, x_norm=None):
        """
        Takes a set of samples and assigns each of them to the centroid with the largest similarity.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, n_features) or NoneType
            Pre-normalized samples, only used by the default metric. Computed from `x` when None.

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        inertia : torch.Tensor (scalar)
            Sum of the similarities between every sample and its most similar centroid.
        """
        if callable(self.metric):
            sim = self.metric(x, self.cluster_centers_)
        else:
            samples = x_norm if x_norm is not None else self._normalize(x)
            centers = self.center_norm if self.center_norm is not None \
                else self._normalize(self.cluster_centers_)
            sim = similarity_matrix(samples, centers, pre_normalized=True)
        most_similar = torch.max(sim, dim=1)
        return most_similar.indices, torch.sum(most_similar.values)
class KMeans(_BaseKMeans):
    """
    K-Means

    Full-batch K-Means operating on a single tensor of samples. Accepts the same parameters as
    `_BaseKMeans`.
    """
    def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
                 similarity_based=False, eps=1e-6):
        super(KMeans, self).__init__(n_clusters=n_clusters, init=init, n_init=n_init, max_iter=max_iter,
                                     metric=metric, similarity_based=similarity_based, eps=eps)
        # Private copy of user-provided centroids: `fit` updates the centroids in place, and the base
        # class maps a tensor `init` to the 'k-means++' method, which would otherwise overwrite it.
        self._init_centers = init.clone() if isinstance(init, torch.Tensor) else None

    def _initialize(self, x, x_norm):
        """
        Initializes the centroid coordinates and resets the training state.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If the initialization method is not one of the supported ones.
        """
        self.labels_ = None
        self.inertia_ = 0
        self.n_iter_ = 0
        if self._init_centers is not None:
            # Honour explicitly provided centroids; clone so that refitting starts from the same point.
            self.cluster_centers_ = self._init_centers.clone()
            self.n_clusters = self.cluster_centers_.size(0)
        elif callable(self.init_method):
            self.cluster_centers_ = self.init_method(x)
            self.n_clusters = self.cluster_centers_.size(0)
        elif self.init_method == 'k-means++':
            self._initialize_kpp(x, x_norm)
        elif self.init_method == 'discern':
            self._initialize_discern(x, x_norm)
        elif self.init_method == 'random':
            self._initialize_random(x)
        else:
            raise NotImplementedError("Initialization `{}` not supported.".format(self.init_method))
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_kpp(self, x, x_norm):
        """
        Initializes the centroid coordinates using K-Means++.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if not isinstance(self.n_clusters, int):
            raise NotImplementedError("K-Means++ expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        # NOTE(review): the norms are forwarded only when NOT similarity_based here, while
        # `_initialize_discern` forwards them only when similarity_based -- confirm both are intended.
        self.cluster_centers_ = k_means_pp(x, n_clusters=self.n_clusters,
                                           x_norm=x_norm if not self.similarity_based else None)
        return self

    def _initialize_discern(self, x, x_norm):
        """
        Initializes the centroid coordinates using DISCERN and sets `n_clusters` to the number of
        centroids it returns.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self
        """
        self.cluster_centers_ = discern(x, n_clusters=self.n_clusters,
                                        x_norm=x_norm if self.similarity_based else None)
        self.n_clusters = self.cluster_centers_.size(0)
        return self

    def _initialize_random(self, x):
        """
        Initializes the centroid coordinates by randomly selecting distinct training samples.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if not isinstance(self.n_clusters, int):
            raise NotImplementedError("Randomized K-Means expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        chosen = random.sample(range(x.size(0)), self.n_clusters)
        self.cluster_centers_ = x[chosen, :]
        return self

    def fit(self, x):
        """
        Initializes and fits the centroids using the samples given w.r.t the metric.

        Training stops after `max_iter` iterations, or earlier once the inertia of two consecutive
        assignments differs by less than `eps`.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        self
        """
        x_norm = self._normalize(x)
        self._initialize(x, x_norm)
        self.inertia_ = None
        for itr in range(self.max_iter):
            labels, inertia = self._assign(x, x_norm)
            converged = self.inertia_ is not None and abs(self.inertia_ - inertia) < self.eps
            self.labels_ = labels
            self.inertia_ = inertia
            if converged:
                break
            for c in range(self.n_clusters):
                idx = torch.where(labels == c)[0]
                if len(idx) == 0:
                    # The mean over zero rows is NaN and would poison every later distance, so an
                    # empty cluster keeps its previous centroid.
                    continue
                self.cluster_centers_[c, :] = torch.mean(torch.index_select(x, 0, idx), dim=0)
            self.center_norm = self._normalize(self.cluster_centers_)
            self.n_iter_ = itr + 1
        return self

    def transform(self, x):
        """
        Assigns the samples given to the clusters w.r.t the centroid coordinates and metric.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        labels, _ = self._assign(x, self._normalize(x))
        return labels

    def fit_transform(self, x):
        """
        Fits the centroids using the samples given w.r.t the metric, returns the final assignments.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        self.fit(x)
        return self.labels_