├── .gitattributes
├── utils
├── __init__.py
└── scores.py
├── src
├── __init__.py
├── utils
│ ├── __init__.py
│ ├── tools.py
│ ├── validations.py
│ ├── norms.py
│ └── metrics.py
├── dataset.py
├── _kmeanspp.py
├── _discern.py
├── minibatchkmeans.py
└── kmeans.py
├── README.md
├── .gitignore
└── examples
├── MiniBatch++_Yale.ipynb
├── MiniBatch_BBC.ipynb
└── MiniBatch_BBC_CUDA.ipynb
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Utils
6 | """
7 | from .scores import purity_score
8 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | """
6 |
7 | from .kmeans import KMeans
8 | from .minibatchkmeans import MiniBatchKMeans
9 | from .dataset import KMeansDataset
10 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Common util functions
6 | """
7 |
8 | from .metrics import distance_matrix, self_distance_matrix, similarity_matrix, self_similarity_matrix
9 | from .norms import row_norm, squared_norm
10 | from .tools import torch_unravel_index
11 |
--------------------------------------------------------------------------------
/src/utils/tools.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Misc tools
6 | """
7 | import torch
8 |
9 |
def torch_unravel_index(index, shape):
    """
    Convert a flat, row-major index into one coordinate per dimension of
    ``shape`` (a torch-friendly counterpart of ``numpy.unravel_index``).

    Adapted from a snippet by ModarTensai -- PyTorch Forums
    https://discuss.pytorch.org/u/ModarTensai

    Parameters
    ----------
    index : int
        Flat index into an array of the given shape.
    shape : tuple
        Shape of that array.

    Returns
    -------
    index : tuple
        Coordinates, ordered from the first to the last dimension.
    """
    coords = []
    remainder = index
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(shape):
        remainder, coord = remainder // size, remainder % size
        coords.append(coord)
    coords.reverse()
    return tuple(coords)
30 |
--------------------------------------------------------------------------------
/src/utils/validations.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Validation utils
6 | """
7 | import torch
8 |
9 |
def distance_validation(x):
    """
    Force every entry of a distance matrix to be non-negative.

    Entries below zero are replaced by 0.0; all other entries are unchanged.

    Parameters
    ----------
    x : torch.Tensor

    Returns
    -------
    x_out : torch.Tensor
        A new tensor with the same shape as ``x``.
    """
    non_negative = x.clamp_min(0.0)
    return non_negative
23 |
24 |
def similarity_validation(x):
    """
    Restrict every entry of a similarity matrix to the interval [0, 1].

    Values below 0 become 0.0 and values above 1 become 1.0.

    Parameters
    ----------
    x : torch.Tensor

    Returns
    -------
    x_out : torch.Tensor
        A new tensor with the same shape as ``x``.
    """
    bounded = x.clamp(min=0.0, max=1.0)
    return bounded
38 |
--------------------------------------------------------------------------------
/src/utils/norms.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Norm utils
6 | """
7 |
8 | from torch.nn import functional as F
9 |
10 |
def squared_norm(x):
    """
    Computes and returns the squared L2 norm of each row of the input 2d tensor.
    Useful for computing euclidean distance matrix.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)

    Returns
    -------
    x_squared_norm : torch.Tensor of shape (n, 1)
        Column vector, so it broadcasts against (n, p) matrices and can be
        indexed with ``x_squared_norm[idx, :]``.
    """
    # Reduce over features, then reshape to a column vector.
    return (x ** 2).sum(1).view(-1, 1)
25 |
26 |
def row_norm(x):
    """
    Scale every row of the input 2d tensor to unit L2 length.
    Useful for computing cosine similarity matrix.

    Rows whose norm is zero stay zero; ``F.normalize`` guards the division
    with its default epsilon.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)

    Returns
    -------
    x_normalized : torch.Tensor of shape (n, m)
    """
    unit_rows = F.normalize(x, dim=1, p=2)
    return unit_rows
41 |
--------------------------------------------------------------------------------
/utils/scores.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def purity_score(y_true, y_pred):
    """
    Computes the purity score (clustering accuracy).

    Each predicted cluster is credited with the number of samples of its most
    frequent ground-truth class; the credits are summed and divided by the
    number of samples.

    Parameters
    ----------
    y_true : torch.Tensor[int] or torch.Tensor[long] and of shape (n_samples, )
        Ground truth labels

    y_pred : torch.Tensor[int] or torch.Tensor[long] and of shape (n_samples, )
        Predicted labels

    Returns
    -------
    accuracy : float

    Raises
    ------
    ValueError
        If the inputs are empty or have different numbers of samples.
    """
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    n = y_true.size(0)
    if n == 0:
        raise ValueError("purity_score requires at least one sample.")
    if y_pred.size(0) != n:
        raise ValueError("y_true and y_pred must contain the same number of samples.")

    # Map arbitrary label values to dense 0..k-1 indices.
    unique_classes, class_idx = torch.unique(y_true, return_inverse=True)
    unique_clusters, cluster_idx = torch.unique(y_pred, return_inverse=True)
    num_classes = unique_classes.numel()
    num_clusters = unique_clusters.numel()

    # Build the (class x cluster) contingency table in one vectorized pass.
    # bincount yields int64 counts; the previous int16 table overflowed
    # once any cell exceeded 32767 samples.
    flat_idx = class_idx * num_clusters + cluster_idx.to(class_idx.device)
    scores = torch.bincount(flat_idx, minlength=num_classes * num_clusters)
    scores = scores.view(num_classes, num_clusters)
    return scores.max(0).values.sum().item() / n
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Torch-based K-Means
2 | A torch-based implementation of K-Means, MiniBatch K-Means, K-Means++ and more with customizable distance metrics,
3 | and similarity-based clustering.
4 |
5 | ## Notes
6 | Please note that this repository is still in WIP phase, but feel free to jump in.
7 |
8 | The goal is to reach the fastest and cleanest implementation of K-Means, K-Means++ and Mini-Batch K-Means using
9 | PyTorch for CUDA-enabled clustering.
10 |
11 |
12 | Here's the progress so far:
13 |
14 | :white_check_mark: K-Means
15 |
16 | :white_check_mark: Similarity-based K-Means (Spherical K-Means)
17 |
18 | :white_check_mark: Custom metrics for K-Means
19 |
20 | :white_check_mark: K-Means++ initialization
21 |
22 | :white_check_mark: DISCERN initialization
23 |
24 | :white_check_mark: Purity score
25 |
26 | :white_check_mark: MiniBatch K-Means
27 |
28 | :black_square_button: (Testing) MiniBatch K-Means++ initialization
29 |
30 | :black_square_button: (In progress) MiniBatch K-Means optimized by torch.optim
31 |
32 | Successful implementation, much faster than the previous MiniBatch K-Means implementation, but not as accurate.
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | K-Means Dataset
6 | """
7 | import random
8 | from torch.utils.data import Dataset
9 | from .utils import squared_norm, row_norm
10 |
11 |
class KMeansDataset(Dataset):
    """
    K-Means Compatible Dataset

    Wraps a 2d data tensor. When ``metric`` is the string ``'default'``, a
    per-row norm tensor is precomputed and returned alongside each sample:
    ``row_norm(x)`` if ``similarity_based`` else ``squared_norm(x)``.
    For any other metric no norms are stored (``data_norm`` is None).

    Parameters
    ----------
    x : torch.Tensor of shape (n_samples, n_features)
    metric : str or other, default='default'
    similarity_based : bool, default=False
    """
    def __init__(self, x, metric='default', similarity_based=False):
        self.data = x
        self.data_norm = None
        if isinstance(metric, str) and metric == 'default':
            self.data_norm = row_norm(x) if similarity_based else squared_norm(x)

    def random_sample(self, n):
        """
        Returns n distinct random samples (drawn without replacement), in the
        same ``(rows, norms)`` form as ``__getitem__`` with a list index.
        """
        idx = random.sample(range(self.data.size(0)), n)
        return self.__getitem__(idx)

    def __len__(self):
        return len(self.data)

    @property
    def dim(self):
        """Number of features, i.e. ``data.shape[1]``."""
        return self.data.shape[1]

    def __getitem__(self, idx):
        """
        Returns ``(rows, norms)`` for ``idx``.

        ``norms`` is the matching slice of ``data_norm`` when norms were
        precomputed. Otherwise it is a list with one ``None`` per selected
        index, or a single ``None`` when ``idx`` is a scalar index.
        """
        rows = self.data[idx, :]
        if self.data_norm is not None:
            return rows, self.data_norm[idx, :]
        try:
            n_items = len(idx)
        except TypeError:
            # Scalar index (e.g. the int passed by a DataLoader's default
            # sampler): previously `len(idx)` raised here.
            return rows, None
        return rows, [None for _ in range(n_items)]
40 |
--------------------------------------------------------------------------------
/src/_kmeanspp.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | K-Means++ initializer
6 |
7 | Arthur, David, and Sergei Vassilvitskii. k-means++: The advantages of careful seeding. Stanford, 2006.
8 | Manuscript available at: http://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
9 | """
10 | import time
11 | import numpy as np
12 | import torch
13 | from .utils import distance_matrix, squared_norm
14 |
15 |
def k_means_pp(x, n_clusters, x_norm=None):
    """
    K-Means++ initialization

    Based on Scikit-Learn's implementation. The first centroid is drawn
    uniformly at random. Each further centroid is picked from
    ``2 + int(log(n_clusters))`` candidates, which are sampled with
    probability proportional to their current distance (as returned by
    ``distance_matrix``) to the closest chosen centroid. The candidate that
    minimizes the total potential is kept.

    Parameters
    ----------
    x : torch.Tensor of shape (n_training_samples, n_features)
    n_clusters : int
    x_norm : torch.Tensor of shape (n_training_samples, 1) or (n_training_samples, ) or NoneType
        Squared row norms of ``x``, as produced by ``squared_norm``. Computed
        from ``x`` when None.

    Returns
    -------
    centroids : torch.Tensor of shape (n_clusters, n_features)
    """
    if x_norm is None:
        x_norm = squared_norm(x)
    else:
        # The indexing below (x_norm[idx, :]) needs a column vector; this also
        # accepts a 1d tensor of norms, which previously raised an IndexError.
        x_norm = x_norm.reshape(-1, 1)
    n_samples, n_features = x.shape

    centroids = torch.zeros((n_clusters, n_features), dtype=x.dtype, device=x.device)

    # Number of candidates evaluated per step (same heuristic as scikit-learn).
    n_local_trials = 2 + int(np.log(n_clusters))

    initial_centroid_idx = torch.randint(low=0, high=n_samples, size=(1,), device=x.device)[0]
    centroids[0, :] = x[initial_centroid_idx, :]

    # dist_mat, shape (1, n_samples): distance of every point to its closest centroid so far.
    dist_mat = distance_matrix(x=centroids[0, :].unsqueeze(0), y=x,
                               x_norm=x_norm[initial_centroid_idx, :].unsqueeze(0), y_norm=x_norm)
    current_potential = dist_mat.sum(1)

    for c in range(1, n_clusters):
        # Inverse-CDF sampling of candidates, weighted by dist_mat.
        rand_vals = torch.rand(n_local_trials, device=x.device) * current_potential
        candidate_ids = torch.searchsorted(torch.cumsum(dist_mat.squeeze(0), dim=0), rand_vals)
        # Keep indices in range if a sample lands past the last cumulative value.
        torch.clamp_max(candidate_ids, dist_mat.size(1) - 1, out=candidate_ids)

        distance_to_candidates = distance_matrix(x=x[candidate_ids, :], y=x,
                                                 x_norm=x_norm[candidate_ids, :], y_norm=x_norm)

        # Closest-centroid distances if each candidate were added: (n_local_trials, n_samples).
        distance_to_candidates = torch.where(dist_mat < distance_to_candidates, dist_mat, distance_to_candidates)
        candidates_potential = distance_to_candidates.sum(1)

        # Keep the candidate that yields the lowest total potential.
        best_candidate = torch.argmin(candidates_potential)
        current_potential = candidates_potential[best_candidate]
        dist_mat = distance_to_candidates[best_candidate].unsqueeze(0)
        best_candidate = candidate_ids[best_candidate]

        centroids[c, :] = x[best_candidate, :]

    return centroids
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | # Idea
133 | .idea
134 |
135 | # CSVs
136 | *.csv
137 |
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | Metric utils
6 | """
7 |
8 | import torch
9 | from .norms import squared_norm, row_norm
10 | from .validations import distance_validation, similarity_validation
11 |
12 |
def distance_matrix(x, y, x_norm=None, y_norm=None):
    """
    Returns the pairwise squared euclidean distance matrix between the two input 2d tensors.

    Uses the expansion ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>. Negative entries
    are then clamped to zero by ``distance_validation``.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)
    y : torch.Tensor of shape (p, m)
    x_norm : torch.Tensor of shape (n, 1) or (n, ) or NoneType
        Precomputed squared row norms of ``x``; computed if None.
    y_norm : torch.Tensor of shape (p, 1) or (p, ) or NoneType
        Precomputed squared row norms of ``y``; computed if None.

    Returns
    -------
    distance_matrix : torch.Tensor of shape (n, p)
    """
    x_norm = squared_norm(x) if x_norm is None else x_norm
    y_norm = squared_norm(y) if y_norm is None else y_norm
    # Force a column (n, 1) and a row (1, p) so they broadcast to (n, p).
    # With 1d norms, the former `x_norm + y_norm.T` silently mis-broadcast.
    mat = x_norm.reshape(-1, 1) + y_norm.reshape(1, -1) - 2.0 * torch.mm(x, y.T)
    return distance_validation(mat)
32 |
33 |
def self_distance_matrix(x):
    """
    Returns the matrix of squared euclidean distances between every pair of
    rows of the input 2d tensor.

    The differences are computed explicitly through broadcasting, which
    materializes an (n, n, m) intermediate tensor.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)

    Returns
    -------
    distance_matrix : torch.Tensor of shape (n, n)
    """
    # (1, n, m) - (n, 1, m) -> (n, n, m)
    pairwise_diff = x[None, :, :] - x[:, None, :]
    squared_dist = pairwise_diff.pow(2).sum(dim=2)
    return distance_validation(squared_dist)
47 |
48 |
def similarity_matrix(x, y, pre_normalized=False):
    """
    Returns the pairwise cosine similarity matrix between the two input 2d tensors.

    Rows are L2-normalized first unless ``pre_normalized`` is set. The raw dot
    products are then passed through ``similarity_validation``.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)
    y : torch.Tensor of shape (p, m)
    pre_normalized : bool, default=False
        Whether the inputs are already row-normalized

    Returns
    -------
    similarity_matrix : torch.Tensor of shape (n, p)
    """
    if not pre_normalized:
        x, y = row_norm(x), row_norm(y)
    cosine = x.matmul(y.T)
    return similarity_validation(cosine)
67 |
68 |
def self_similarity_matrix(x, pre_normalized=False):
    """
    Returns the pairwise similarity matrix between the rows of the input 2d tensor.

    Cosine similarities are rescaled from [-1, 1] to [0, 1] via (1 + cos) / 2
    before being passed through ``similarity_validation``.

    Parameters
    ----------
    x : torch.Tensor of shape (n, m)
    pre_normalized : bool, default=False
        Whether the input is already row-normalized

    Returns
    -------
    similarity_matrix : torch.Tensor of shape (n, n)
    """
    unit_rows = x if pre_normalized else row_norm(x)
    cosine = unit_rows.matmul(unit_rows.T)
    return similarity_validation((1 + cosine) / 2)
87 |
--------------------------------------------------------------------------------
/src/_discern.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | DISCERN initializer
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | from .utils import self_similarity_matrix, row_norm, torch_unravel_index
11 |
12 |
def discern(x, n_clusters=None, max_n_clusters=None, x_norm=None):
    """
    DISCERN initialization

    Selects rows of ``x`` as initial centroids using the self-similarity
    matrix of the row-normalized data. The least similar pair of points seeds
    the first two centroids. After that, the remaining point with the smallest
    membership score ``max_sim ** 2 * min_sim * (max_sim - min_sim)`` is added,
    one at a time. The min/max are taken over that point's similarities to
    the centroids chosen so far.

    Parameters
    ----------
    x : torch.Tensor of shape (n_training_samples, n_features)
    n_clusters : int or NoneType
        Estimates the number of clusters if set to None (or to a value below 2),
        from the curvature of the recorded membership scores.

    max_n_clusters : int or NoneType, default=None
        Bounds the selection loop, which runs while the centroid counter is at
        most this value. Defaults to int(n_training_samples / 2) + 1 if None.

    x_norm : torch.Tensor of shape (n_training_samples, n_features) or NoneType
        Row-normalized ``x``; computed with ``row_norm`` if None.

    Returns
    -------
    centroids : torch.Tensor of shape (n_clusters, n_features)
        Selected rows of ``x``. Fewer rows are returned if the candidates or
        the ``max_n_clusters`` bound run out first.
    """
    max_n_clusters = max_n_clusters if max_n_clusters is not None else int(x.size(0) / 2) + 1

    if x_norm is None:
        x_norm = row_norm(x)
    similarity_matrix = self_similarity_matrix(x_norm, True)

    # Seed with the least similar pair of points.
    centroid_idx_0, centroid_idx_1 = torch_unravel_index(int(torch.argmin(similarity_matrix)), similarity_matrix.shape)
    centroid_idx = [centroid_idx_0, centroid_idx_1]

    remaining = [y for y in range(len(similarity_matrix)) if y not in centroid_idx]

    # Rows: chosen centroids; columns: remaining candidates.
    similarity_submatrix = similarity_matrix[centroid_idx, :][:, remaining]

    ctr = 2
    find_n_clusters = n_clusters is None or n_clusters < 2

    # membership_values[k] stores the score of the centroid picked when ctr == k.
    membership_values = None if not find_n_clusters else np.zeros(max_n_clusters + 1, dtype=float)

    # Continue while any candidate is left. The former `len(remaining) > 1`
    # condition could never select the final remaining point.
    while remaining and ctr <= max_n_clusters:
        if n_clusters is not None and 1 < n_clusters <= len(centroid_idx):
            break

        min_vector = torch.min(similarity_submatrix, dim=0).values
        max_vector = torch.max(similarity_submatrix, dim=0).values
        diff_vector = max_vector - min_vector
        membership_vector = torch.square(max_vector) * min_vector * diff_vector
        min_idx = int(torch.argmin(membership_vector).item())
        min_value = float(membership_vector[min_idx])

        new_centroid_idx = remaining[min_idx]
        if find_n_clusters:
            membership_values[ctr] = min_value
        centroid_idx.append(new_centroid_idx)
        remaining.remove(new_centroid_idx)
        similarity_submatrix = similarity_matrix[centroid_idx, :][:, remaining]
        ctr += 1

    if find_n_clusters:
        membership_values = membership_values[:ctr]
        # TODO: torch implementation
        # Signed curvature of the membership curve; its minimum picks the cluster count.
        rx = range(0, len(membership_values))
        dy = np.gradient(membership_values, rx)
        d2y = np.gradient(dy, rx)
        kappa = (d2y / ((1 + (dy ** 2)) ** (3 / 2)))
        predicted_n_clusters = int(np.argmin(kappa))
        n_clusters = max(predicted_n_clusters, 2)

    centroid_idx = centroid_idx[:n_clusters]

    return x[centroid_idx, :]
83 |
--------------------------------------------------------------------------------
/examples/MiniBatch++_Yale.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "1. Import dependencies"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "import torch\n",
19 | "from torch.utils.data import DataLoader\n",
20 | "import time\n",
21 | "\n",
22 | "from utils.scores import purity_score as purity\n",
23 | "\n",
24 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n",
25 | "\n",
26 | "from sklearn.feature_extraction.text import TfidfVectorizer"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "2. Import data"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "#X, Y = load_yale()\n",
43 | "\n",
44 | "X = torch.from_numpy(X)\n",
45 | "Y = torch.from_numpy(Y)\n",
46 | "\n",
47 | "if torch.cuda.is_available():\n",
48 | " X = X.cuda()\n",
49 | " Y = Y.cuda()"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 3,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "similarity_based = True\n",
59 | "batch_size = 16\n",
60 | "n_clusters = 15"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "dataset = KMeansDataset(X, similarity_based=similarity_based)\n",
70 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "3. Mini-Batch K-Means"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 5,
83 | "metadata": {
84 | "scrolled": true
85 | },
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "Clustering finished in 2.15 seconds.\n",
92 | "[Supervised Performance] Test Accuracy: 73.33 %\n"
93 | ]
94 | }
95 | ],
96 | "source": [
97 | "start = time.time()\n",
98 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n",
99 | "km.fit(dataloader)\n",
100 | "labels = km.transform_tensor(X)\n",
101 | "km_time = time.time() - start\n",
102 | "acc = purity(Y, labels) * 100\n",
103 | "\n",
104 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
105 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "4. Mini-Batch K-Means++"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 6,
118 | "metadata": {
119 | "scrolled": true
120 | },
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "Clustering finished in 2.14 seconds.\n",
127 | "[Supervised Performance] Test Accuracy: 86.67 %\n"
128 | ]
129 | }
130 | ],
131 | "source": [
132 | "start = time.time()\n",
133 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='k-means++', similarity_based=similarity_based)\n",
134 | "km.fit(dataloader)\n",
135 | "labels = km.transform_tensor(X)\n",
136 | "km_time = time.time() - start\n",
137 | "acc = purity(Y, labels) * 100\n",
138 | "\n",
139 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
140 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
141 | ]
142 | }
143 | ],
144 | "metadata": {
145 | "kernelspec": {
146 | "display_name": "Python 3",
147 | "language": "python",
148 | "name": "python3"
149 | },
150 | "language_info": {
151 | "codemirror_mode": {
152 | "name": "ipython",
153 | "version": 3
154 | },
155 | "file_extension": ".py",
156 | "mimetype": "text/x-python",
157 | "name": "python",
158 | "nbconvert_exporter": "python",
159 | "pygments_lexer": "ipython3",
160 | "version": "3.8.5"
161 | }
162 | },
163 | "nbformat": 4,
164 | "nbformat_minor": 2
165 | }
166 |
--------------------------------------------------------------------------------
/examples/MiniBatch_BBC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "1. Import dependencies"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "import torch\n",
19 | "from torch.utils.data import DataLoader\n",
20 | "import time\n",
21 | "\n",
22 | "from sklearn.preprocessing import LabelEncoder\n",
23 | "from sklearn.metrics import silhouette_score\n",
24 | "from utils.scores import purity_score as purity\n",
25 | "\n",
26 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n",
27 | "\n",
28 | "from sklearn.feature_extraction.text import TfidfVectorizer"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "2. Import data"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [
43 | {
44 | "name": "stdout",
45 | "output_type": "stream",
46 | "text": [
47 | "(2127, 29422)\n"
48 | ]
49 | }
50 | ],
51 | "source": [
52 | "df = pd.read_csv('bbc.csv')\n",
53 | "\n",
54 | "vectorizer = TfidfVectorizer()\n",
55 | "ds = vectorizer.fit_transform(df['content']).todense()\n",
56 | "classes = df['label'].to_numpy()\n",
57 | "\n",
58 | "idx = np.random.permutation(len(classes))\n",
59 | "ds = ds[idx, :]\n",
60 | "labels = LabelEncoder().fit_transform(np.asarray(classes))[idx]\n",
61 | "n_train = int(len(labels) * 0.7)\n",
62 | "\n",
63 | "print(ds.shape)\n",
64 | "\n",
65 | "X_train = torch.from_numpy(ds[:n_train, :])\n",
66 | "y_train = torch.from_numpy(labels[:n_train])\n",
67 | "X_test = torch.from_numpy(ds[n_train:, :])\n",
68 | "y_test = torch.from_numpy(labels[n_train:])\n",
69 | "\n",
70 | "if torch.cuda.is_available():\n",
71 | " X_train = X_train.cuda()\n",
72 | " y_train = y_train.cuda()\n",
73 | " X_test = X_test.cuda()\n",
74 | " y_test = y_test.cuda()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 3,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "similarity_based = False\n",
84 | "batch_size = 64\n",
85 | "n_clusters = 5"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "dataset = KMeansDataset(X_train, similarity_based=similarity_based)\n",
95 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "3. K-Means"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 5,
108 | "metadata": {
109 | "scrolled": true
110 | },
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "Clustering finished in 4.41 seconds.\n",
117 | "[Supervised Performance] Test Accuracy: 65.73 %\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "start = time.time()\n",
123 | "km = KMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n",
124 | "km.fit(X_train)\n",
125 | "labels = km.transform(X_test)\n",
126 | "km_time = time.time() - start\n",
127 | "acc = purity(y_test, labels) * 100\n",
128 | "\n",
129 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
130 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "4. Mini-Batch K-Means"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 6,
143 | "metadata": {
144 | "scrolled": true
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Clustering finished in 45.6 seconds.\n",
152 | "[Supervised Performance] Test Accuracy: 84.66 %\n"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "start = time.time()\n",
158 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n",
159 | "km.fit(dataloader)\n",
160 | "labels = km.transform_tensor(X_test)\n",
161 | "km_time = time.time() - start\n",
162 | "acc = purity(y_test, labels) * 100\n",
163 | "\n",
164 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
165 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
166 | ]
167 | }
168 | ],
169 | "metadata": {
170 | "kernelspec": {
171 | "display_name": "Python 3",
172 | "language": "python",
173 | "name": "python3"
174 | },
175 | "language_info": {
176 | "codemirror_mode": {
177 | "name": "ipython",
178 | "version": 3
179 | },
180 | "file_extension": ".py",
181 | "mimetype": "text/x-python",
182 | "name": "python",
183 | "nbconvert_exporter": "python",
184 | "pygments_lexer": "ipython3",
185 | "version": "3.8.5"
186 | }
187 | },
188 | "nbformat": 4,
189 | "nbformat_minor": 2
190 | }
191 |
--------------------------------------------------------------------------------
/examples/MiniBatch_BBC_CUDA.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "1. Import dependencies"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "import torch\n",
19 | "from torch.utils.data import DataLoader\n",
20 | "import time\n",
21 | "\n",
22 | "from sklearn.preprocessing import LabelEncoder\n",
23 | "from sklearn.metrics import silhouette_score\n",
24 | "from utils.scores import purity_score as purity\n",
25 | "\n",
26 | "from src import KMeans, MiniBatchKMeans, KMeansDataset\n",
27 | "\n",
28 | "from sklearn.feature_extraction.text import TfidfVectorizer"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "2. Import data"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [
43 | {
44 | "name": "stdout",
45 | "output_type": "stream",
46 | "text": [
47 | "(2127, 29422)\n"
48 | ]
49 | }
50 | ],
51 | "source": [
52 | "df = pd.read_csv('bbc.csv')\n",
53 | "\n",
54 | "vectorizer = TfidfVectorizer()\n",
55 | "ds = vectorizer.fit_transform(df['content']).todense()\n",
56 | "classes = df['label'].to_numpy()\n",
57 | "\n",
58 | "idx = np.random.permutation(len(classes))\n",
59 | "ds = ds[idx, :]\n",
60 | "labels = LabelEncoder().fit_transform(np.asarray(classes))[idx]\n",
61 | "n_train = int(len(labels) * 0.7)\n",
62 | "\n",
63 | "print(ds.shape)\n",
64 | "\n",
65 | "X_train = torch.from_numpy(ds[:n_train, :])\n",
66 | "y_train = torch.from_numpy(labels[:n_train])\n",
67 | "X_test = torch.from_numpy(ds[n_train:, :])\n",
68 | "y_test = torch.from_numpy(labels[n_train:])\n",
69 | "\n",
70 | "if torch.cuda.is_available():\n",
71 | " X_train = X_train.cuda()\n",
72 | " y_train = y_train.cuda()\n",
73 | " X_test = X_test.cuda()\n",
74 | " y_test = y_test.cuda()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 3,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "similarity_based = False\n",
84 | "batch_size = 64\n",
85 | "n_clusters = 5"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "dataset = KMeansDataset(X_train, similarity_based=similarity_based)\n",
95 | "dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "3. K-Means"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 5,
108 | "metadata": {
109 | "scrolled": true
110 | },
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "Clustering finished in 0.449 seconds.\n",
117 | "[Supervised Performance] Test Accuracy: 56.34 %\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "start = time.time()\n",
123 | "km = KMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n",
124 | "km.fit(X_train)\n",
125 | "labels = km.transform(X_test)\n",
126 | "km_time = time.time() - start\n",
127 | "acc = purity(y_test, labels) * 100\n",
128 | "\n",
129 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
130 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "4. Mini-Batch K-Means"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 6,
143 | "metadata": {
144 | "scrolled": true
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Clustering finished in 11.4 seconds.\n",
152 | "[Supervised Performance] Test Accuracy: 73.40 %\n"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "start = time.time()\n",
158 | "km = MiniBatchKMeans(n_clusters=n_clusters, n_init=1, init='random', similarity_based=similarity_based)\n",
159 | "km.fit(dataloader)\n",
160 | "labels = km.transform_tensor(X_test)\n",
161 | "km_time = time.time() - start\n",
162 | "acc = purity(y_test, labels) * 100\n",
163 | "\n",
164 | "print(\"Clustering finished in {:.3} seconds.\".format(km_time))\n",
165 | "print(\"[Supervised Performance] Test Accuracy: {:.2f} %\".format(acc))"
166 | ]
167 | }
168 | ],
169 | "metadata": {
170 | "kernelspec": {
171 | "display_name": "Python 3",
172 | "language": "python",
173 | "name": "python3"
174 | },
175 | "language_info": {
176 | "codemirror_mode": {
177 | "name": "ipython",
178 | "version": 3
179 | },
180 | "file_extension": ".py",
181 | "mimetype": "text/x-python",
182 | "name": "python",
183 | "nbconvert_exporter": "python",
184 | "pygments_lexer": "ipython3",
185 | "version": "3.8.5"
186 | }
187 | },
188 | "nbformat": 4,
189 | "nbformat_minor": 2
190 | }
191 |
--------------------------------------------------------------------------------
/src/minibatchkmeans.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | MiniBatch K-Means
6 |
7 | Sculley, David. "Web-scale k-means clustering." Proceedings of the 19th international conference on
8 | World wide web. 2010. Manuscript available at: https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
9 |
10 | NOTE: In the original mini-batch K-Means paper by David Sculley, the SGD updates were applied per sample
11 | in a mini-batch. In this implementation, in order to make things faster, the updates are applied per
12 | mini-batch by taking a sum over the mini-batch samples within the class, similar to the original K-Means
13 | algorithm. Another difference from the original method in the paper is that there, one mini-batch is
14 | generated at each iteration by randomly sampling the original dataset, while in this implementation the
15 | entire dataset is shuffled and split into mini-batches, all of which are processed at each iteration.
16 | """
17 | import numpy as np
18 | import torch
19 | from .kmeans import _BaseKMeans
20 | from ._kmeanspp import k_means_pp
21 |
22 |
class MiniBatchKMeans(_BaseKMeans):
    """
    Mini Batch K-Means

    Parameters
    ----------
    n_clusters : int or NoneType
        The number of clusters or `K`. May be None only when `init` is a tensor, in which case it is
        taken from the tensor's first dimension.

    init : 'random', 'k-means++' or torch.Tensor of shape (n_clusters, n_features), default='k-means++'
        One of the pre-defined initialization methods {'random', 'k-means++'} or a tensor of the initial
        centroid coordinates. Any other value falls back to 'k-means++'.
        NOTE: 'k-means++' is computed on the first mini-batch of the dataloader only.

    n_init : int, default=10
        Ignored (Number of initializations).
        NOTE: not yet supported.

    max_iter : int, default=200
        Maximum number of passes over the dataloader.

    metric : 'default' or callable, default='default'
        Distance metric when similarity_based=False and similarity metric otherwise. Default is 'default'
        which uses L2 distance and cosine similarity as the distance and similarity metrics respectively.
        The callable metrics should take in two tensors of shapes (n, d) and (m, d) and return a tensor of
        shape (n, m).

    similarity_based : bool, default=False
        Whether the metric is a similarity metric or not.

    eps : float, default=1e-6
        Early stopping threshold on the absolute change of the inertia between two consecutive passes.

    Attributes
    ----------
    labels_ : torch.Tensor of shape (n_training_samples,)
        Training cluster assignments computed with the final centroids. When the training dataloader
        uses a RandomSampler (shuffle=True), they are given in the order of the underlying dataset.

    cluster_centers_ : torch.Tensor of shape (n_clusters, n_features)
        Final centroid coordinates

    inertia_ : scalar torch.Tensor
        Sum, over the last pass, of the per-batch inertias returned by `_assign` (sum of distances when
        not similarity_based and sum of similarities when similarity_based).

    n_iter_ : int
        The number of passes over the dataloader performed during training.

    cluster_counts : numpy.ndarray of shape (n_clusters,)
        Number of training assignments accumulated per cluster; sets the per-cluster learning rate.
    """
    _SUPPORTED_INIT = ('random', 'k-means++')

    def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
                 similarity_based=False, eps=1e-6):
        # Only the methods implemented in `_initialize` are kept; a tensor is handled separately below,
        # every other value falls back to 'k-means++' (as before).
        if not isinstance(init, torch.Tensor) and not (isinstance(init, str) and init in self._SUPPORTED_INIT):
            init = 'k-means++'
        super(MiniBatchKMeans, self).__init__(n_clusters=n_clusters, init=init, n_init=n_init,
                                              max_iter=max_iter, metric=metric,
                                              similarity_based=similarity_based, eps=eps)
        # The base class records a tensor `init` under the 'k-means++' method name, which would make
        # `_initialize` overwrite it. Keep a private copy so user-provided centroids are actually used.
        self._init_centers = init.detach().clone() if isinstance(init, torch.Tensor) else None

    def _initialize(self, dataloader):
        """
        Initializes the centroid coordinates and resets the training state.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If the initialization method is not one of the supported ones.
        """
        self.labels_ = None
        self.inertia_ = 0
        self.n_iter_ = 0
        init_centers = getattr(self, '_init_centers', None)
        if init_centers is not None:
            # Copy: centroids are updated in place during training.
            self.cluster_centers_ = init_centers.clone()
            self.n_clusters = self.cluster_centers_.size(0)
        elif self.init_method == 'k-means++':
            self._initialize_kpp(dataloader)
        elif self.init_method == 'random':
            self._initialize_random(dataloader)
        else:
            raise NotImplementedError("Initialization `{}` not supported.".format(self.init_method))
        # Refresh the cached centroid norms; otherwise the first assignments would use norms that are
        # missing (None) or left over from a previous call to `fit`.
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_random(self, dataloader):
        """
        Initializes the centroid coordinates by randomly selecting from the training samples.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if type(self.n_clusters) is not int:
            raise NotImplementedError("Randomized K-Means expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        # TODO: Better implementation of random initialization
        self.cluster_centers_, self.center_norm = dataloader.dataset.random_sample(self.n_clusters)
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_kpp(self, dataloader):
        """
        Initializes the centroid coordinates using K-Means++ on the first mini-batch of the dataloader.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        # TODO: Mini-batch K-Means++
        if type(self.n_clusters) is not int:
            raise NotImplementedError("K-Means++ expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        x, x_norm = next(iter(dataloader))
        self.cluster_centers_ = k_means_pp(x, n_clusters=self.n_clusters,
                                           x_norm=x_norm if not self.similarity_based else None)
        return self

    def fit(self, dataloader):
        """
        Initializes and fits the centroids using the samples given w.r.t the metric.

        Training stops after `max_iter` passes, or earlier once the inertia of two consecutive passes
        differs by less than `eps`. Afterwards `labels_` holds the assignments of all training samples
        with the final centroids.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        self
        """
        self._initialize(dataloader)
        self.inertia_ = None
        self.cluster_counts = np.zeros(self.n_clusters, dtype=int)
        for itr in range(self.max_iter):
            inertia = self._fit_iter(dataloader)
            self.n_iter_ = itr + 1
            converged = self.inertia_ is not None and abs(self.inertia_ - inertia) < self.eps
            # Always keep the latest inertia: it is both the reported value and the reference for the
            # next convergence check.
            self.inertia_ = inertia
            if converged:
                break
        self.labels_ = self.transform(self._ordered_loader(dataloader))
        return self

    @staticmethod
    def _ordered_loader(dataloader):
        """
        Returns a loader yielding the training samples in dataset order.

        A DataLoader driven by a RandomSampler (shuffle=True) is rebuilt over the same dataset with
        sequential order and no dropped last batch; any other loader or iterable is returned unchanged.
        """
        from torch.utils.data import DataLoader, RandomSampler
        if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, RandomSampler):
            return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, shuffle=False,
                              num_workers=dataloader.num_workers, collate_fn=dataloader.collate_fn)
        return dataloader

    def _fit_iter(self, dataloader):
        """
        Performs one pass of mini-batch K-Means over the dataloader, updating the centroids after
        every mini-batch.

        For a cluster with `v_old` previous assignments receiving `m` samples of a batch, the centroid
        becomes `(v_old * c + sum(samples)) / (v_old + m)`. This equals applying Sculley's per-sample
        update `c <- (1 - 1/v) c + x / v` to those `m` samples one after another.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        inertia : scalar torch.Tensor
            Sum of the per-batch inertias of this pass (each computed before that batch's update).
        """
        # TODO: Cleaner and faster implementation
        total_inertia = 0
        for x, x_norm in dataloader:
            labels, inertia = self._assign(x, x_norm)
            total_inertia += inertia
            for c in range(self.n_clusters):
                idx = torch.where(labels == c)[0]
                n_c = idx.numel()
                if n_c == 0:
                    continue
                self.cluster_counts[c] += n_c
                count = int(self.cluster_counts[c])
                center = self.cluster_centers_[c, :]
                batch_sum = torch.sum(torch.index_select(x, 0, idx), dim=0)
                self.cluster_centers_[c, :] = center + (batch_sum - n_c * center) / count
            self.center_norm = self._normalize(self.cluster_centers_)
        return total_inertia

    def transform(self, dataloader):
        """
        Assigns the samples in the dataloader given to the clusters w.r.t the centroid coordinates
        and metric, in the order the dataloader yields them.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        label_list = []
        for x, x_norm in dataloader:
            labels, _ = self._assign(x, x_norm)
            label_list.append(labels)
        return torch.cat(label_list)

    def transform_tensor(self, x):
        """
        Assigns the samples given to the clusters w.r.t the centroid coordinates and metric.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        labels, _ = self._assign(x)
        return labels

    def fit_transform(self, dataloader):
        """
        Fits the centroids using the samples given w.r.t the metric, returns the final assignments.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader[KMeansDataset]

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
            Same as `labels_` after `fit`.
        """
        self.fit(dataloader)
        return self.labels_
236 |
--------------------------------------------------------------------------------
/src/kmeans.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch-based K-Means
3 | by Ali Hassani
4 |
5 | K-Means
6 | """
7 |
8 | import random
9 | import numpy as np
10 | import torch
11 | from .utils import distance_matrix, similarity_matrix, squared_norm, row_norm
12 | from ._kmeanspp import k_means_pp
13 | from ._discern import discern
14 |
15 |
16 | class _BaseKMeans:
17 | """
18 | Base K-Means : DO NOT USE DIRECTLY
19 |
20 | Parameters
21 | ----------
22 | n_clusters : int or NoneType
23 | The number of clusters or `K`. Set to None ONLY when init = 'discern'.
24 |
25 | init : 'random', 'k-means++', 'discern', callable or torch.Tensor of shape (n_clusters, n_features)
26 | Tensor of the initial centroid coordinates, one of the pre-defined methods {'random', 'k-means++',
27 | 'discern'} or callable taking the training data as input and returning the centroid coordinates.
28 |
29 | n_init : int, default=10
30 | Ignored (Number of initializations).
31 | NOTE: not yet supported.
32 |
33 | max_iter : int, default=200
34 | Maximum K-Means iterations.
35 |
36 | metric : 'default' or callable, default='default'
37 | Distance metric when similarity_based=False and similarity metric otherwise. Default is 'default'
38 | which uses L2 distance and cosine similarity as the distance and similarity metrics respectively.
39 | WARNING: This metric does not apply to the pre-defined initialization methods (K-Means++ and DISCERN).
40 | The callable metrics should take in two tensors of shapes (n, d) and (m, d) and return a tensor of
41 | shape (n, m).
42 |
43 | similarity_based : bool, default=False
44 | Whether the metric is a similarity metric or not.
45 |
46 | eps : float, default=1e-6
47 | Threshold for early stopping.
48 |
49 | Attributes
50 | ----------
51 | labels_ : torch.Tensor of shape (n_training_samples,)
52 | Training cluster assignments
53 |
54 | cluster_centers_ : torch.Tensor of shape (n_clusters, n_features)
55 | Final centroid coordinates
56 |
57 | inertia_ : float
58 | Sum of squared errors when not similarity_based and sum of similarities when similarity_based
59 |
60 | n_iter_ : int
61 | The number of training iterations
62 | """
63 | def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
64 | similarity_based=False, eps=1e-6):
65 | self.n_clusters = n_clusters
66 | self.init_method = init if type(init) is str or callable(init) else 'k-means++'
67 | self.cluster_centers_ = init if type(init) is torch.Tensor else None
68 | self.max_iter = max_iter
69 | self.metric = metric if callable(metric) else 'default'
70 | self.similarity_based = similarity_based
71 | self.eps = eps
72 |
73 | self.center_norm = None
74 | self.labels_ = None
75 | self.inertia_ = 0
76 | self.n_iter_ = 0
77 |
78 | def _normalize(self, x):
79 | return row_norm(x) if self.similarity_based else squared_norm(x)
80 |
81 | def _assign(self, x, x_norm=None):
82 | """
83 | Takes a set of samples and assigns them to the clusters w.r.t the centroid coordinates and metric.
84 |
85 | Parameters
86 | ----------
87 | x : torch.Tensor of shape (n_samples, n_features)
88 | x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType
89 |
90 | Returns
91 | -------
92 | labels : torch.Tensor of shape (n_samples,)
93 | """
94 | if self.similarity_based:
95 | return self._similarity_based_assignment(x, x_norm)
96 | return self._distance_based_assignment(x, x_norm)
97 |
98 | def _distance_based_assignment(self, x, x_norm=None):
99 | """
100 | Takes a set of samples and assigns them using the metric to the clusters w.r.t the centroid coordinates.
101 |
102 | Parameters
103 | ----------
104 | x : torch.Tensor of shape (n_samples, n_features)
105 | x_norm : torch.Tensor of shape (n_samples, ) or NoneType
106 |
107 | Returns
108 | -------
109 | labels : torch.Tensor of shape (n_samples,)
110 | """
111 | if callable(self.metric):
112 | dist = self.metric(x, self.cluster_centers_)
113 | else:
114 | dist = distance_matrix(x, self.cluster_centers_, x_norm=x_norm, y_norm=self.center_norm)
115 | return torch.argmin(dist, dim=1), torch.sum(torch.min(dist, dim=1).values)
116 |
117 | def _similarity_based_assignment(self, x, x_norm):
118 | """
119 | Takes a set of samples and assigns them using the metric to the clusters w.r.t the centroid coordinates.
120 |
121 | Parameters
122 | ----------
123 | x : torch.Tensor of shape (n_samples, n_features)
124 | x_norm : torch.Tensor of shape (n_samples, n_features)
125 |
126 | Returns
127 | -------
128 | labels : torch.Tensor of shape (n_samples,)
129 | """
130 | if callable(self.metric):
131 | dist = self.metric(x, self.cluster_centers_)
132 | else:
133 | dist = similarity_matrix(x_norm if x_norm is not None else self._normalize(x),
134 | self.center_norm if self.center_norm is not None \
135 | else self._normalize(self.cluster_centers_),
136 | pre_normalized=True)
137 | return torch.argmax(dist, dim=1), torch.sum(torch.max(dist, dim=1).values)
138 |
139 |
class KMeans(_BaseKMeans):
    """
    K-Means on a single in-memory tensor.

    Alternates between assigning every sample to its closest (or most similar, when
    similarity_based=True) centroid and moving each non-empty cluster's centroid to the mean of its
    samples. Parameters and attributes are those documented on `_BaseKMeans`.
    """
    def __init__(self, n_clusters=None, init='k-means++', n_init=10, max_iter=200, metric='default',
                 similarity_based=False, eps=1e-6):
        super(KMeans, self).__init__(n_clusters=n_clusters, init=init, n_init=n_init, max_iter=max_iter,
                                     metric=metric, similarity_based=similarity_based, eps=eps)
        # The base class records a tensor `init` under the 'k-means++' method name, which would make
        # `_initialize` overwrite it. Keep a private copy so user-provided centroids are actually used.
        self._init_centers = init.detach().clone() if isinstance(init, torch.Tensor) else None

    def _initialize(self, x, x_norm):
        """
        Initializes the centroid coordinates and resets the training state.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If the initialization method is not one of the supported ones.
        """
        self.labels_ = None
        self.inertia_ = 0
        self.n_iter_ = 0
        init_centers = getattr(self, '_init_centers', None)
        if init_centers is not None:
            # Match x's dtype/device and copy, since `fit` updates the centroids in place.
            self.cluster_centers_ = init_centers.to(device=x.device, dtype=x.dtype).clone()
            self.n_clusters = self.cluster_centers_.size(0)
        elif callable(self.init_method):
            # Copy: the callable may return a view of `x`, which in-place centroid updates would corrupt.
            self.cluster_centers_ = self.init_method(x).clone()
            self.n_clusters = self.cluster_centers_.size(0)
        elif self.init_method == 'k-means++':
            self._initialize_kpp(x, x_norm)
        elif self.init_method == 'discern':
            self._initialize_discern(x, x_norm)
        elif self.init_method == 'random':
            self._initialize_random(x)
        else:
            raise NotImplementedError("Initialization `{}` not supported.".format(self.init_method))
        self.center_norm = self._normalize(self.cluster_centers_)
        return self

    def _initialize_kpp(self, x, x_norm):
        """
        Initializes the centroid coordinates using K-Means++.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if type(self.n_clusters) is not int:
            raise NotImplementedError("K-Means++ expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        # Only distance-based norms are forwarded; similarity-based runs let k_means_pp work from x.
        self.cluster_centers_ = k_means_pp(x, n_clusters=self.n_clusters,
                                           x_norm=x_norm if not self.similarity_based else None)
        return self

    def _initialize_discern(self, x, x_norm):
        """
        Initializes the centroid coordinates using DISCERN, which also determines `n_clusters`.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)
        x_norm : torch.Tensor of shape (n_samples, ) or shape (n_samples, n_features), or NoneType

        Returns
        -------
        self
        """
        # Only similarity-based (row-normalized) x_norm is forwarded to discern.
        self.cluster_centers_ = discern(x, n_clusters=self.n_clusters,
                                        x_norm=x_norm if self.similarity_based else None)
        self.n_clusters = self.cluster_centers_.size(0)
        return self

    def _initialize_random(self, x):
        """
        Initializes the centroid coordinates by randomly selecting distinct training samples.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If `n_clusters` is not an int.
        """
        if type(self.n_clusters) is not int:
            raise NotImplementedError("Randomized K-Means expects the number of clusters, given {}.".format(type(
                self.n_clusters)))
        self.cluster_centers_ = x[random.sample(range(x.size(0)), self.n_clusters), :]
        return self

    def fit(self, x):
        """
        Initializes and fits the centroids using the samples given w.r.t the metric.

        Stops after `max_iter` iterations, or earlier once the inertia of two consecutive assignment
        steps differs by less than `eps`.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        self
        """
        x_norm = self._normalize(x)
        self._initialize(x, x_norm)
        self.inertia_ = None
        for itr in range(self.max_iter):
            labels, inertia = self._assign(x, x_norm)
            converged = self.inertia_ is not None and abs(self.inertia_ - inertia) < self.eps
            self.labels_ = labels
            self.inertia_ = inertia
            if converged:
                break
            for c in range(self.n_clusters):
                idx = torch.where(labels == c)[0]
                # Leave an empty cluster's centroid unchanged: the mean of zero rows is NaN, which
                # would make that cluster's metric values NaN in every later iteration.
                if idx.numel() > 0:
                    self.cluster_centers_[c, :] = torch.mean(torch.index_select(x, 0, idx), dim=0)
            self.center_norm = self._normalize(self.cluster_centers_)
            self.n_iter_ = itr + 1
        return self

    def transform(self, x):
        """
        Assigns the samples given to the clusters w.r.t the centroid coordinates and metric.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        labels, _ = self._assign(x, self._normalize(x))
        return labels

    def fit_transform(self, x):
        """
        Fits the centroids using the samples given w.r.t the metric, returns the final assignments.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        labels : torch.Tensor of shape (n_samples,)
        """
        self.fit(x)
        return self.labels_
289 |
--------------------------------------------------------------------------------