├── callbacks ├── callback.py ├── visualization_callback.py ├── timing_callback.py └── visualization_gif_callback.py ├── requirements.txt ├── configs ├── dataset │ ├── s4.yaml │ ├── jain.yaml │ ├── moons.yaml │ └── aggregation.yaml ├── algorithm │ ├── kmeans.yaml │ └── meanshift.yaml └── base.yaml ├── images └── meanshift.gif ├── download_datasets.sh ├── algorithms ├── base_algorithm.py ├── meanshift_base.py ├── kmeans_np.py ├── kmeans_tf_eager.py ├── kmeans_base.py ├── kmeans_pt.py ├── kmeans_tf.py ├── util_np.py ├── util_jax.py ├── meanshift_np.py ├── util_pt.py ├── meanshift_pt.py ├── kmeans_jax.py ├── meanshift_jax.py ├── meanshift_tf_eager.py ├── meanshift_tf.py └── builder.py ├── datasets ├── builder.py └── dataset.py ├── .gitignore ├── main.py ├── utils └── plotting.py └── README.md /callbacks/callback.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class DefaultCallback(): 4 | """Empty base class for callbacks; subclasses define hook methods such as on_epoch_end.""" 5 | pass -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | matplotlib 3 | seaborn 4 | numpy 5 | scikit-learn 6 | imageio -------------------------------------------------------------------------------- /configs/dataset/s4.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: s4 3 | default_bandwidth: 0.07 4 | default_n_clusters: 7 -------------------------------------------------------------------------------- /configs/dataset/jain.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: jain 3 | default_bandwidth: 0.1 4 | default_n_clusters: 2 -------------------------------------------------------------------------------- /images/meanshift.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/creinders/ClusteringAlgorithmsFromScratch/HEAD/images/meanshift.gif -------------------------------------------------------------------------------- /configs/dataset/moons.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: 'moons' 3 | default_bandwidth: 0.2 4 | default_n_clusters: 2 -------------------------------------------------------------------------------- /configs/dataset/aggregation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: aggregation 3 | default_bandwidth: 0.07 4 | default_n_clusters: 7 -------------------------------------------------------------------------------- /configs/algorithm/kmeans.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | algorithm_name: 'kmeans' 3 | n_clusters: null 4 | max_iter: 100 5 | early_stop_threshold: 0.00001 -------------------------------------------------------------------------------- /configs/algorithm/meanshift.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | algorithm_name: 'meanshift' 3 | bandwidth: null 4 | early_stop_threshold: 0.00001 5 | cluster_threshold: 0.0005 -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: aggregation 3 | - algorithm: meanshift 4 | - _self_ 5 | 6 | framework: 'jax' # numpy, pytorch, jax, tensorflow, tensorflow_eager 7 | cuda: true 8 | seed: 0 9 | verbose: true 10 | time: false 11 | time_repeats: 10 12 | plot: true 13 | plot_gif: false -------------------------------------------------------------------------------- /download_datasets.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/sh 2 | # -nc skips files that were already downloaded so the script can be re-run; unzip -o overwrites without prompting 3 | wget -nc http://cs.joensuu.fi/sipu/datasets/Aggregation.txt -P data 4 | wget -nc http://cs.joensuu.fi/sipu/datasets/jain.txt -P data 5 | wget -nc http://cs.joensuu.fi/sipu/datasets/s4.txt -P data 6 | wget -nc http://cs.joensuu.fi/sipu/datasets/s-originals.zip -P data 7 | 8 | unzip -o data/s-originals.zip -d data -------------------------------------------------------------------------------- /algorithms/base_algorithm.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BaseAlgorithm: 4 | 5 | def __init__(self, callback=None, verbose=False) -> None: 6 | self.verbose = verbose 7 | self.callback = callback 8 | 9 | def get_hook(self, name): 10 | return getattr(self.callback, name, None) 11 | 12 | def call_hook(self, name, *args, **kwargs): 13 | callback_op = self.get_hook(name) 14 | if callable(callback_op): 15 | callback_op(*args, **kwargs) -------------------------------------------------------------------------------- /callbacks/visualization_callback.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from callbacks.callback import DefaultCallback 3 | from utils.plotting import plot_clustering 4 | 5 | class VisualizationCallback(DefaultCallback): 6 | 7 | def __init__(self, filename='visualization.png') -> None: 8 | self.filename = filename 9 | 10 | 11 | def on_epoch_end(self, method, X, clusters, assignments): 12 | X = method.tensor_to_numpy(X) 13 | 14 | plot_clustering(X, assignments=assignments, centers=clusters) 15 | plt.savefig(self.filename) 16 | plt.close() # release the figure once it is saved so repeated fits do not accumulate open figures 17 | -------------------------------------------------------------------------------- /callbacks/timing_callback.py: -------------------------------------------------------------------------------- 1 | from callbacks.callback import DefaultCallback 2 | import time 3 | 4 | class TimingCallback(DefaultCallback): 5 | 6 | def __init__(self) -> None: 7 | super().__init__() 8 | 9 | self.start = None 10 
| self.n = 0 11 | self.total_duration = 0 12 | 13 | # perf_counter is a monotonic high-resolution clock, unlike time.time(), so it is the right tool for durations 14 | def on_main_loop_start(self): 15 | self.start = time.perf_counter() 16 | 17 | def on_main_loop_end(self): 18 | end = time.perf_counter() 19 | duration = end - self.start 20 | 21 | self.n += 1 22 | self.total_duration += duration 23 | -------------------------------------------------------------------------------- /datasets/builder.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import load_moons, normalize, load_text_data, load_s4 2 | 3 | def build_dataset(cfg): 4 | import os 5 | 6 | res_path = os.path.join(os.path.dirname(__file__), '../data') 7 | 8 | dataset_name = cfg.dataset 9 | 10 | if dataset_name == 'moons': 11 | X, _ = load_moons(500, random_state=cfg.seed) 12 | elif dataset_name == 'aggregation': 13 | X, _ = load_text_data(os.path.join(res_path, 'Aggregation.txt'), random_state=cfg.seed) 14 | elif dataset_name == 'jain': 15 | X, _ = load_text_data(os.path.join(res_path, 'jain.txt'), random_state=cfg.seed) 16 | elif dataset_name == 's4': 17 | X, _ = load_s4(random_state=cfg.seed) 18 | else: 19 | raise ValueError('unknown dataset: {}'.format(dataset_name)) 20 | # make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=cfg.seed) 21 | X = normalize(X) 22 | 23 | return X 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | output 3 | outputs 4 | visualization.png 5 | visualization.gif 6 | 7 | .DS_Store 8 | .huskyrc.json 9 | out 10 | log.log 11 | **/node_modules 12 | *.pyc 13 | *.vsix 14 | **/.vscode/.ropeproject/** 15 | **/testFiles/**/.cache/** 16 | *.noseids 17 | .nyc_output 18 | .vscode-test 19 | __pycache__ 20 | npm-debug.log 21 | **/.mypy_cache/** 22 | !yarn.lock 23 | coverage/ 24 | cucumber-report.json 25 | **/.vscode-test/** 26 | **/.vscode test/** 27 | **/.vscode-smoke/** 28 | **/.venv*/ 29 | port.txt
30 | precommit.hook 31 | pythonFiles/lib/** 32 | debug_coverage*/** 33 | languageServer/** 34 | languageServer.*/** 35 | bin/** 36 | obj/** 37 | .pytest_cache 38 | tmp/** 39 | .python-version 40 | .vs/ 41 | test-results*.xml 42 | xunit-test-results.xml 43 | build/ci/performance/performance-results.json 44 | !build/ 45 | debug*.log 46 | debugpy*.log 47 | pydevd*.log 48 | nodeLanguageServer/** 49 | nodeLanguageServer.*/** 50 | dist/** 51 | # translation files 52 | *.xlf 53 | *.nls.*.json 54 | *.i18n.json -------------------------------------------------------------------------------- /algorithms/meanshift_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from algorithms.base_algorithm import BaseAlgorithm 3 | 4 | 5 | class MeanShiftBase(BaseAlgorithm): 6 | 7 | def __init__(self, bandwidth, early_stop_threshold=0.01, cluster_threshold=0.1, verbose=False, callback=None) -> None: 8 | super().__init__(callback=callback, verbose=verbose) 9 | self.bandwidth = bandwidth 10 | self.early_stop_threshold = early_stop_threshold 11 | self.cluster_threshold = cluster_threshold 12 | 13 | def prepare(self, X): 14 | return X, X.copy() 15 | 16 | def _main_loop(self, X, clusters): 17 | raise NotImplementedError('subclasses must implement _main_loop and return (centers, assignments)') 18 | 19 | def tensor_to_numpy(self, t): 20 | return np.array(t) 21 | 22 | def finalize(self, centers, assignments): 23 | return self.tensor_to_numpy(centers), self.tensor_to_numpy(assignments) 24 | 25 | def fit(self, X): 26 | X, initial_centers = self.prepare(X) 27 | 28 | self.call_hook('on_main_loop_start') 29 | centers, assignments = self._main_loop(X, initial_centers) 30 | self.call_hook('on_main_loop_end') 31 | 32 | centers, assignments = self.finalize(centers, assignments) 33 | 34 | self.call_hook('on_epoch_end', self, X, centers, assignments) 35 | 36 | return centers, assignments 37 | -------------------------------------------------------------------------------- /algorithms/kmeans_np.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from algorithms.kmeans_base import KMeansBase 3 | from algorithms.util_np import scatter_mean0 4 | 5 | 6 | class KMeansNumpy(KMeansBase): 7 | def __init__(self, *args, **kwargs) -> None: 8 | super().__init__(*args, **kwargs) 9 | 10 | def _main_loop(self, X, centers): 11 | # distance[n, k] below is the squared Euclidean distance between point X[n] and centers[k] 12 | for iteration in range(self.max_iter): 13 | distance = np.sum(np.square((X[:, :, None] - np.transpose(centers)[None, ...])), axis=1) 14 | assignments = np.argmin(distance, axis=1) 15 | 16 | # k-Means can assign no points to a cluster center, in that case keep old value 17 | center_means, assigned_counts = scatter_mean0(X, assignments, axis_size=self.n_clusters, return_counts=True) 18 | new_centers = np.copy(centers) 19 | new_centers[assigned_counts > 0] = center_means[assigned_counts > 0] 20 | 21 | diff = np.sum(np.square(new_centers - centers)) 22 | 23 | if self.verbose: 24 | print('Iteration {}: {} difference'.format(iteration, diff)) 25 | 26 | if diff < self.early_stop_threshold: 27 | break 28 | centers = new_centers 29 | 30 | if self.get_hook('on_iteration_end'): 31 | self.call_hook('on_iteration_end', self, iteration, X, centers, assignments) 32 | 33 | return centers, assignments 34 | -------------------------------------------------------------------------------- /algorithms/kmeans_tf_eager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .kmeans_base import KMeansBase 5 | 6 | 7 | class KMeansTensorflowEager(KMeansBase): 8 | 9 | def __init__(self, *args, **kwargs) -> None: 10 | super().__init__(*args, **kwargs) 11 | 12 | 13 | def prepare(self, X): 14 | centers = self.init_clusters(X, self.n_clusters) 15 | centers = tf.Variable(centers) # K x C 16 | X = tf.convert_to_tensor(X) # B x C 17 | 18 | return X, centers 19 | 20 | def _main_loop(self, X, centers): 21 | 22 | for _ in
range(self.max_iter): 23 | 24 | distance = tf.reduce_sum(tf.square((tf.expand_dims(X, axis=2) - tf.expand_dims(tf.transpose(centers, perm=(1, 0)), axis=0))), axis=1) 25 | assignments = tf.math.argmin(distance, axis=1) 26 | 27 | # k-Means can assign no points to a cluster center, in that case keep old value 28 | means = tf.math.unsorted_segment_mean(X, assignments, num_segments=self.n_clusters) 29 | counts = tf.math.unsorted_segment_sum(tf.ones_like(assignments), assignments, num_segments=self.n_clusters) 30 | 31 | new_centers = tf.where(counts[:, None] > 0, means, centers) 32 | 33 | diff = tf.reduce_sum(tf.square((new_centers - centers))) 34 | centers = new_centers 35 | if diff < self.early_stop_threshold: 36 | break 37 | 38 | return centers, assignments 39 | 40 | def tensor_to_numpy(self, t): 41 | return np.array(t) 42 | 43 | -------------------------------------------------------------------------------- /algorithms/kmeans_base.py: -------------------------------------------------------------------------------- 1 | """Shared k-means scaffolding: random initial centers (init_clusters) and the fit() driver around a subclass _main_loop.""" 2 | import numpy as np 3 | from .base_algorithm import BaseAlgorithm 4 | 5 | 6 | class KMeansBase(BaseAlgorithm): 7 | 8 | def __init__(self, n_clusters, max_iter=100, early_stop_threshold=0.01, seed = None, callback=None, verbose=False) -> None: 9 | super().__init__(callback=callback, verbose=verbose) 10 | self.n_clusters = n_clusters 11 | self.max_iter = max_iter 12 | self.early_stop_threshold = early_stop_threshold 13 | self.rng = np.random.RandomState(seed) 14 | 15 | def init_clusters(self, X, n_clusters): 16 | n = X.shape[0] 17 | 18 | i = self.rng.permutation(n)[:n_clusters] 19 | centers = X[i] 20 | return centers 21 | 22 | def prepare(self, X): 23 | centers = self.init_clusters(X, self.n_clusters) 24 | return X, centers 25 | 26 | def _main_loop(self, X, centers): 27 | raise NotImplementedError('subclasses must implement _main_loop and return (centers, assignments)') 28 | 29 | def tensor_to_numpy(self, t): 30 | return np.array(t) 31 | 32 | def finalize(self, centers, assignments): 33 | return self.tensor_to_numpy(centers), self.tensor_to_numpy(assignments) 34 | 35 | def fit(self, X): 36 | X, initial_centers = self.prepare(X) 37 | 38 | self.call_hook('on_main_loop_start') 39 |
centers, assignments = self._main_loop(X, initial_centers) 40 | self.call_hook('on_main_loop_end') 41 | 42 | centers, assignments = self.finalize(centers, assignments) 43 | 44 | self.call_hook('on_epoch_end', self, X, centers, assignments) 45 | 46 | return centers, assignments 47 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import numpy as np 3 | from callbacks.timing_callback import TimingCallback 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | from datasets.builder import build_dataset 8 | from algorithms.builder import build_algorithm 9 | from callbacks.visualization_callback import VisualizationCallback 10 | from callbacks.visualization_gif_callback import VisualizationGifCallback 11 | 12 | 13 | @hydra.main(config_path="configs", config_name="base", version_base="1.2") 14 | def main(cfg: DictConfig) -> None: 15 | 16 | X = build_dataset(cfg) 17 | X = X.astype(np.float32) 18 | 19 | verbose = cfg.verbose 20 | n = 1 21 | 22 | if cfg.time: 23 | if cfg.plot_gif or cfg.plot: 24 | print('WARNING: cannot plot and time at the same time') 25 | callback = TimingCallback() 26 | verbose = False 27 | n = cfg.time_repeats 28 | 29 | elif cfg.plot_gif: 30 | if cfg.framework != 'numpy': 31 | print('WARNING: set framework to numpy for generating gifs') 32 | 33 | callback = VisualizationGifCallback(cfg.algorithm_name, plot_png=cfg.plot) 34 | 35 | elif cfg.plot: 36 | 37 | callback = VisualizationCallback() 38 | 39 | else: 40 | callback = None 41 | 42 | clustering = build_algorithm(cfg, callback=callback, verbose=verbose) 43 | for _ in range(n): 44 | clustering.fit(X.copy()) 45 | 46 | if cfg.time: 47 | print('{}s per iteration (n={})'.format(callback.total_duration / callback.n, callback.n)) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | 
-------------------------------------------------------------------------------- /algorithms/kmeans_pt.py: -------------------------------------------------------------------------------- 1 | from algorithms.kmeans_base import KMeansBase 2 | from algorithms.util_pt import scatter_mean0 3 | import torch 4 | 5 | 6 | class KMeansPytorch(KMeansBase): 7 | 8 | def __init__(self, *args, cuda=True, **kwargs) -> None: 9 | super().__init__(*args, **kwargs) 10 | self.cuda = cuda 11 | 12 | def prepare(self, X): 13 | centers = self.init_clusters(X, self.n_clusters) 14 | 15 | X = torch.from_numpy(X) 16 | centers = torch.from_numpy(centers) 17 | 18 | if self.cuda: 19 | X = X.cuda() 20 | centers = centers.cuda() 21 | 22 | return X, centers 23 | 24 | def _main_loop(self, X, centers): 25 | for iteration in range(self.max_iter): 26 | 27 | distance = (X[:, :, None] - centers.permute((1, 0))[None, ...]).square().sum(1) 28 | assignments = torch.argmin(distance, dim=1) 29 | 30 | # k-Means can assign no points to a cluster center, in that case keep old value 31 | center_means, assigned_counts = scatter_mean0(X, assignments, axis_size=self.n_clusters, return_counts=True) 32 | new_centers = torch.clone(centers) 33 | new_centers[assigned_counts > 0] = center_means[assigned_counts > 0] 34 | 35 | diff = (new_centers - centers).square().sum() 36 | if self.verbose: 37 | print('Iteration {}: {} difference'.format(iteration, diff)) 38 | 39 | if diff < self.early_stop_threshold: 40 | break 41 | 42 | centers = new_centers 43 | 44 | return centers, assignments 45 | 46 | def tensor_to_numpy(self, t): 47 | return t.cpu().detach().numpy() -------------------------------------------------------------------------------- /algorithms/kmeans_tf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .kmeans_base import KMeansBase 5 | 6 | 7 | class KMeansTensorflow(KMeansBase): 8 | 9 | def __init__(self, *args,
**kwargs) -> None: 10 | super().__init__(*args, **kwargs) 11 | 12 | 13 | def prepare(self, X): 14 | centers = self.init_clusters(X, self.n_clusters) 15 | centers = tf.Variable(centers) # K x C 16 | X = tf.convert_to_tensor(X) # B x C 17 | 18 | return X, centers 19 | 20 | def _main_loop(self, X, centers): 21 | # step and cond are traced into TensorFlow graphs (a bare `tf.function` line is a no-op; the decorator needs `@`) 22 | @tf.function 23 | def step(iteration, centers, assignments, diff): 24 | distance = tf.reduce_sum(tf.square((tf.expand_dims(X, axis=2) - tf.expand_dims(tf.transpose(centers, perm=(1, 0)), axis=0))), axis=1) 25 | assignments = tf.math.argmin(distance, axis=1) 26 | 27 | # k-Means can assign no points to a cluster center, in that case keep old value 28 | means = tf.math.unsorted_segment_mean(X, assignments, num_segments=self.n_clusters) 29 | counts = tf.math.unsorted_segment_sum(tf.ones_like(assignments), assignments, num_segments=self.n_clusters) 30 | 31 | new_centers = tf.where(counts[:, None] > 0, means, centers) 32 | 33 | diff = tf.reduce_sum(tf.square((new_centers - centers))) 34 | 35 | return iteration + 1, new_centers, assignments, diff 36 | 37 | c = tf.constant(self.early_stop_threshold, dtype=X.dtype) 38 | 39 | @tf.function 40 | def cond(iteration, centers, assignments, diff): 41 | return tf.logical_and(iteration < self.max_iter, tf.math.greater(diff, c)) # Python `and` cannot combine symbolic tensors 42 | # loop variables start as typed tensors (not None / a Python int) so every iteration carries the same dtypes 43 | iteration, centers, assignments, diff = tf.while_loop(cond, step, [tf.constant(0), centers, tf.zeros_like(X[:, 0], dtype=tf.int64), tf.constant(100, dtype=X.dtype)]) 44 | return centers, assignments 45 | 46 | def tensor_to_numpy(self, t): 47 | return np.array(t) 48 | -------------------------------------------------------------------------------- /algorithms/util_np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def connected_components_undirected(conn): 5 | """ 6 | Find connected components in undirected graph connectivity matrix 7 | """ 8 | 9 | assert conn.dtype == bool 10 | assert len(conn.shape) == 2 11 | assert np.all(conn == conn.T) 12 | 13 | result = np.full(len(conn), -1, dtype=int) 14 | 15 | curr_idx = 0 16 | 17 | while True: 18 | # Find next unassigned 19 | cand = np.flatnonzero(result == -1) 20 | 21 | if len(cand) == 0: 22 | break 23 | 24 | mask = 
np.zeros(len(conn), dtype=bool) 25 | mask[cand[0]] = True 26 | mask_sum = 1 27 | 28 | while True: 29 | mask = np.logical_or(mask, np.any(conn[mask], axis=0)) 30 | s = np.sum(mask) 31 | 32 | if s == mask_sum: 33 | break 34 | mask_sum = s 35 | 36 | assert np.all(result[mask] == -1) 37 | result[mask] = curr_idx 38 | 39 | curr_idx += 1 40 | 41 | assert np.all(result > -1) 42 | return curr_idx, result 43 | 44 | 45 | def scatter_mean0(src, index, axis_size=None, return_counts=False): 46 | """ 47 | Scatter mean on 0-th axis 48 | """ 49 | 50 | if axis_size is None: 51 | axis_size = np.max(index) + 1 52 | 53 | # Target shape is target size and remaining value shape without indexed dimension 54 | accumulator = np.zeros((axis_size,) + src.shape[1:], dtype=src.dtype) 55 | numerator = np.zeros(axis_size, dtype=int) 56 | 57 | np.add.at(accumulator, index, src) 58 | np.add.at(numerator, index, 1) 59 | 60 | result = accumulator / numerator.reshape(axis_size, *((1,) * len(src.shape[1:]))) 61 | 62 | if return_counts: 63 | return result, numerator 64 | else: 65 | return result 66 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def normalize(d, epsilon=1e-15): 5 | assert len(d.shape) == 2 6 | if not np.issubdtype(d.dtype, np.floating): 7 | raise ValueError("Invalid dtype: {}".format(d.dtype)) 8 | lo = np.min(d, axis=0) 9 | hi = np.max(d, axis=0) 10 | return (d - lo) / np.maximum(hi - lo, epsilon) 11 | 12 | def load_text_data(path, shuffle=True, random_state=None): 13 | total = np.loadtxt(path) 14 | 15 | if shuffle: 16 | rdn = np.random.RandomState(random_state) 17 | rdn.shuffle(total) 18 | 19 | # Last column should be label, subtract minimum to always have 0 indexing 20 | total[:, -1] = total[:, -1] - np.min(total[:, -1]) 21 | x, y = total[..., :-1], total[..., -1].astype(int) 22 | 23 | return x, y 24 | 25 | 
26 | def load_moons(n_total=2000, random_state=None): 27 | from sklearn.datasets import make_moons 28 | 29 | return make_moons(n_samples=n_total, noise=.05, random_state=random_state) 30 | 31 | 32 | def load_pa(path): 33 | with open(path, 'r') as f: 34 | lines = f.readlines() 35 | 36 | result = [] 37 | is_start = False 38 | # integers after the first line starting with '---' are collected; blank lines are skipped 39 | for line in lines: 40 | if is_start and line.strip(): 41 | result.append(int(line)) 42 | elif line.startswith('---'): 43 | is_start = True 44 | 45 | return np.array(result) - np.min(result) 46 | 47 | 48 | def load_s4(n_total=1000, random_state=None): 49 | import os 50 | 51 | data_folder = os.path.join(os.path.dirname(__file__), '..', 'data') 52 | x = np.loadtxt(os.path.join(data_folder, 's4.txt')) 53 | y = load_pa(os.path.join(data_folder, 's4-label.pa')) 54 | assert len(x) == len(y) 55 | # subsample without replacement so no point appears twice in the subset 56 | if n_total is not None: 57 | rdn = np.random.RandomState(random_state) 58 | idxs = rdn.choice(len(x), n_total, replace=False) 59 | x, y = x[idxs], y[idxs] 60 | 61 | return x, y -------------------------------------------------------------------------------- /algorithms/util_jax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import jit 4 | 5 | def connected_components_undirected(conn): 6 | """ 7 | Find connected components in undirected graph connectivity matrix 8 | """ 9 | 10 | assert conn.dtype == jnp.bool_ 11 | assert len(conn.shape) == 2 12 | assert jnp.all(conn == conn.T) 13 | 14 | result = jnp.full(conn.shape[:1], -1, dtype=jnp.int32) 15 | 16 | curr_idx = 0 17 | 18 | n = conn.shape[0] 19 | 20 | for i in range(n): 21 | # Find next unassigned 22 | 23 | if result[i] >= 0: 24 | continue 25 | 26 | mask = jnp.zeros(len(conn), dtype=jnp.bool_) 27 | mask = mask.at[i].set(True) 28 | mask_sum = 1 29 | 30 | while True: 31 | mask = jnp.logical_or(mask, jnp.any(conn[mask], axis=0)) 32 | s = jnp.sum(mask).item() 33 | 34 | if s == mask_sum: 35 | break 36 | mask_sum = s 37 | 38 | assert 
jnp.all(result[mask] == -1) 39 | result = result.at[mask].set(curr_idx) 40 | 41 | curr_idx += 1 42 | 43 | assert jnp.all(result > -1) 44 | return curr_idx, result 45 | 46 | 47 | @jit 48 | def calculate_mean(cluster_index, assignments, X): 49 | q = assignments == cluster_index 50 | mask = q.astype(jnp.int32) 51 | c = jnp.sum(mask) 52 | s = jnp.sum(X * mask[:, None], axis=0) 53 | m = s / c 54 | 55 | return m 56 | 57 | 58 | def scatter_mean0(src, index): 59 | """ 60 | Scatter mean on 0-th axis 61 | """ 62 | 63 | index_max = jnp.max(index) + 1 64 | # count = jnp.zeros(index_max, dtype=jnp.int32) 65 | # count = count.at[index].add(1) 66 | 67 | calculate_mean_vmap = jax.vmap(calculate_mean, in_axes=(0, None, None)) 68 | indices = jnp.arange(index_max) 69 | clusters = calculate_mean_vmap(indices, index, src) 70 | 71 | return clusters 72 | 73 | -------------------------------------------------------------------------------- /algorithms/meanshift_np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from algorithms.meanshift_base import MeanShiftBase 3 | 4 | 5 | class MeanShiftNumpy(MeanShiftBase): 6 | 7 | def __init__(self, *args, **kwargs) -> None: 8 | super().__init__(*args, **kwargs) 9 | 10 | def prepare(self, X): 11 | return X, X.copy() 12 | 13 | def distance(self, a, b): 14 | # shape a: N x C 15 | # shape b: M x C 16 | 17 | d = np.sum(np.square(a[:, None, :] - b[None, :, :]), axis=-1) 18 | return d 19 | 20 | def kernel(self, distances): 21 | return np.exp(-0.5 * ((distances / self.bandwidth ** 2))) 22 | 23 | def _main_loop(self, X, clusters): 24 | 25 | iteration = 0 26 | while True: 27 | iteration += 1 28 | 29 | d = self.distance(clusters, X) 30 | w = self.kernel(d) 31 | 32 | new_centers = w[:, :, None] * X[None, :, :] 33 | w_sum = w.sum(1) 34 | new_centers = np.sum(new_centers, axis=1) / w_sum[:, None] 35 | 36 | diff = np.sum(np.square(new_centers - clusters)) 37 | 38 | if self.verbose: 39 |
print('Iteration {}: {} difference'.format(iteration, diff.item())) 40 | 41 | if diff < self.early_stop_threshold: 42 | break 43 | clusters = new_centers 44 | 45 | if self.get_hook('on_iteration_end'): 46 | _, plot_assignments = self._group_clusters(clusters) 47 | self.call_hook('on_iteration_end', self, iteration, X, clusters, plot_assignments) 48 | 49 | clusters, assignments = self._group_clusters(clusters) 50 | return clusters, assignments 51 | 52 | def _group_clusters(self, points): 53 | from algorithms.util_np import connected_components_undirected, scatter_mean0 54 | 55 | _, cluster_ids = connected_components_undirected(self.distance(points, points) < self.cluster_threshold) 56 | cluster_centers = scatter_mean0(points, cluster_ids) 57 | return cluster_centers, cluster_ids 58 | 59 | -------------------------------------------------------------------------------- /algorithms/util_pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def connected_components_undirected(conn): 5 | """ 6 | Find connected components in undirected graph connectivity matrix 7 | """ 8 | 9 | assert conn.dtype == torch.bool 10 | assert len(conn.shape) == 2 11 | assert torch.all(conn == conn.T) 12 | 13 | result = torch.full(conn.shape[:1], -1, dtype=torch.int64, device=conn.device) 14 | 15 | curr_idx = 0 16 | 17 | while True: 18 | # Find next unassigned 19 | cand = torch.nonzero(result == -1).flatten() 20 | 21 | if len(cand) == 0: 22 | break 23 | 24 | mask = torch.zeros(len(conn), dtype=torch.bool, device=conn.device) 25 | mask[cand[0]] = True 26 | mask_sum = 1 27 | 28 | while True: 29 | mask = torch.logical_or(mask, torch.any(conn[mask], dim=0)) 30 | s = torch.sum(mask).item() 31 | 32 | if s == mask_sum: 33 | break 34 | mask_sum = s 35 | 36 | assert torch.all(result[mask] == -1) 37 | result[mask] = curr_idx 38 | 39 | curr_idx += 1 40 | 41 | assert torch.all(result > -1) 42 | return curr_idx, result 43 | 44 | 45 | def
scatter_mean0(src, index, axis_size=None, return_counts=False): 46 | """ 47 | Scatter mean on 0-th axis 48 | """ 49 | 50 | if axis_size is None: 51 | axis_size = int(torch.max(index).item()) + 1 # plain int so it can be used in shape tuples 52 | 53 | # Target shape is target size and remaining value shape without indexed dimension 54 | accumulator = torch.zeros((axis_size,) + src.shape[1:], dtype=src.dtype, device=src.device) 55 | numerator = torch.zeros(axis_size, dtype=torch.int64, device=src.device) 56 | 57 | torch.index_put_(accumulator, (index,), src, accumulate=True) 58 | ones = torch.ones(len(index), dtype=torch.int64, device=numerator.device) 59 | torch.index_put_(numerator, (index,), ones, accumulate=True) 60 | 61 | result = accumulator / numerator.reshape(axis_size, *((1,) * len(src.shape[1:]))) 62 | 63 | if return_counts: 64 | return result, numerator 65 | else: 66 | return result 67 | -------------------------------------------------------------------------------- /algorithms/meanshift_pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from algorithms.meanshift_base import MeanShiftBase 4 | 5 | 6 | class MeanShiftPytorch(MeanShiftBase): 7 | 8 | def __init__(self, cuda=True, *args, **kwargs) -> None: 9 | super().__init__(*args, **kwargs) 10 | 11 | self.cuda = cuda 12 | 13 | self.pi = torch.acos(torch.tensor(-1.)) # acos(-1) = pi (asin(1) would give pi / 2) 14 | if cuda: 15 | self.pi = self.pi.cuda() 16 | 17 | def prepare(self, X): 18 | X = torch.from_numpy(X) 19 | 20 | if self.cuda: 21 | X = X.cuda() 22 | 23 | return X, X.clone() 24 | 25 | def distance(self, a, b): 26 | # shape a: N x C 27 | # shape b: M x C 28 | 29 | d = torch.sum(torch.square(a[:, None, :] - b[None, :, :]), dim=-1) 30 | return d 31 | 32 | def kernel(self, distances): 33 | return torch.exp(-0.5 * ((distances / self.bandwidth ** 2))) 34 | 35 | def _main_loop(self, X, clusters): 36 | 37 | iteration = 0 38 | while True: 39 | iteration += 1 40 | d = self.distance(clusters, X) 41 | w = self.kernel(d) 42 | 43 | new_centers = w[:, :,
None] * X[None, :, :] 44 | w_sum = w.sum(1) 45 | new_centers = torch.sum(new_centers, dim=1) / w_sum[:, None] 46 | 47 | diff = torch.sum(torch.square(new_centers - clusters)) 48 | 49 | if self.verbose: 50 | print('Iteration {}: {} difference'.format(iteration, diff.item())) 51 | 52 | if diff < self.early_stop_threshold: 53 | break 54 | clusters = new_centers 55 | 56 | clusters, assignments = self._group_clusters(clusters) 57 | return clusters, assignments 58 | 59 | def _group_clusters(self, points): 60 | from algorithms.util_pt import connected_components_undirected, scatter_mean0 61 | 62 | _, cluster_ids = connected_components_undirected(self.distance(points, points) < self.cluster_threshold) 63 | cluster_centers = scatter_mean0(points, cluster_ids) 64 | return cluster_centers, cluster_ids 65 | 66 | def tensor_to_numpy(self, t): 67 | return t.cpu().detach().numpy() 68 | -------------------------------------------------------------------------------- /algorithms/kmeans_jax.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .kmeans_base import KMeansBase 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import jit 6 | 7 | @jit 8 | def cluster_update(cluster_index, old_cluster, assignments, X): 9 | q = assignments == cluster_index 10 | mask = q.astype(jnp.int32) 11 | c = jnp.sum(mask) 12 | s = jnp.sum(X * mask[:, None], axis=0) 13 | m = jnp.where(c > 0, s / c, old_cluster) # k-Means can assign no points to a cluster center, in that case keep old value 14 | 15 | return m 16 | 17 | 18 | @jit 19 | def step(X, centers): 20 | cluster_update_vmap = jax.vmap(cluster_update, in_axes=(0, 0, None, None)) 21 | 22 | distance = jnp.sum(jnp.square((X[:, :, None] - jnp.transpose(centers, (1, 0))[None, ...])), axis=1) 23 | assignments = jnp.argmin(distance, axis=1) 24 | 25 | a = jnp.arange(centers.shape[0]) 26 | new_centers = cluster_update_vmap(a, centers, assignments, X) 27 | 28 | diff = jnp.sum(jnp.square((new_centers - centers))) 29 | return new_centers, diff, assignments 30 | 31 | 32 | class KMeansJax(KMeansBase): 
33 | 34 | def __init__(self, *args, **kwargs) -> None: 35 | super().__init__(*args, **kwargs) 36 | 37 | # jax.config.update('jax_platform_name', 'cpu') 38 | 39 | def prepare(self, X): 40 | centers = self.init_clusters(X, self.n_clusters) 41 | centers = jnp.asarray(centers) # K x C 42 | X = jnp.asarray(X) # B x C 43 | 44 | return X, centers 45 | 46 | 47 | def _main_loop(self, X, centers): 48 | 49 | @jit 50 | def while_step(arg): 51 | iteration, centers, assignments, diff = arg 52 | new_centers, diff, assignments = step(X, centers) 53 | 54 | return (iteration + 1, new_centers, assignments, diff) 55 | 56 | @jit 57 | def cond(arg): 58 | iteration, centers, assignments, diff = arg 59 | return (iteration < self.max_iter) & (diff > self.early_stop_threshold) 60 | 61 | assignments = jnp.zeros(X.shape[0], dtype=jnp.int32) 62 | 63 | iteration, centers, assignments, diff = jax.lax.while_loop( 64 | cond, 65 | while_step, 66 | (0, centers, assignments, 1000) 67 | ) 68 | 69 | return centers, assignments 70 | 71 | def tensor_to_numpy(self, t): 72 | return np.array(t) 73 | -------------------------------------------------------------------------------- /algorithms/meanshift_jax.py: -------------------------------------------------------------------------------- 1 | """Mean-shift clustering implemented with JAX (jit-compiled steps driven by jax.lax.while_loop).""" 2 | import numpy as np 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import jit 6 | from functools import partial 7 | 8 | from algorithms.meanshift_base import MeanShiftBase 9 | from algorithms.util_jax import connected_components_undirected, scatter_mean0 10 | 11 | class MeanShiftJax(MeanShiftBase): 12 | 13 | def __init__(self, cuda=True, *args, **kwargs) -> None: 14 | super().__init__(*args, **kwargs) 15 | 16 | if not cuda: 17 | jax.config.update('jax_platform_name', 'cpu') 18 | 19 | def prepare(self, X): 20 | X = jnp.asarray(X) 21 | return X, jnp.copy(X) 22 | 23 | @partial(jit, static_argnums=0) 24 | def distance(self, a, b): 25 | # shape a: N x C 26 | # shape b: M x C 27 | 28 | d = 
jnp.sum(jnp.square(a[:, None, :] - b[None, :, :]), axis=-1) 29 | return d 30 | 31 | @partial(jit, static_argnums=0) 32 | def kernel(self, distances): 33 | return jnp.exp(-0.5 * ((distances / self.bandwidth ** 2))) 34 | 35 | def _main_loop(self, X, clusters): 36 | 37 | @jit 38 | def step(arg): 39 | clusters, diff = arg 40 | d = self.distance(clusters, X) 41 | w = self.kernel(d) 42 | 43 | new_centers = w[:, :, None] * X[None, :, :] 44 | w_sum = w.sum(1) 45 | new_centers = jnp.sum(new_centers, axis=1) / w_sum[:, None] 46 | 47 | diff = jnp.sum(jnp.square(new_centers - clusters)) 48 | 49 | return (new_centers, diff) 50 | 51 | @jit 52 | def cond(arg): 53 | clusters, diff = arg 54 | return diff > self.early_stop_threshold 55 | 56 | clusters, diff = jax.lax.while_loop( 57 | cond, 58 | step, 59 | (clusters, 1000) 60 | ) 61 | 62 | clusters, assignments = self._group_clusters(clusters) 63 | return clusters, assignments 64 | 65 | def _group_clusters(self, points): 66 | _, cluster_ids = connected_components_undirected(self.distance(points, points) < self.cluster_threshold) 67 | cluster_centers = scatter_mean0(points, cluster_ids) 68 | return cluster_centers, cluster_ids 69 | 70 | def tensor_to_numpy(self, t): 71 | return np.array(t) 72 | -------------------------------------------------------------------------------- /callbacks/visualization_gif_callback.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from callbacks.callback import DefaultCallback 3 | from utils.plotting import plot_clustering 4 | import os 5 | import imageio 6 | 7 | class VisualizationGifCallback(DefaultCallback): 8 | 9 | def __init__(self, algorithm_name, filename='visualization.png', filename_gif='visualization.gif', plot_png=True, tmp_folder='output', fps=2, repeat_last_result=5) -> None: 10 | self.tmp_folder = tmp_folder 11 | self.filename_gif = filename_gif 12 | self.fps = fps 13 | self.repeat_last_result = repeat_last_result 14 
| self.center_per_cluster = True if algorithm_name == 'meanshift' else False 15 | self.plot_png = plot_png 16 | self.filename = filename 17 | 18 | os.makedirs(self.tmp_folder, exist_ok=True) 19 | 20 | self.files = [] 21 | self.fig = plt.figure(figsize=(6, 6)) 22 | self.ax = self.fig.add_subplot(111) 23 | 24 | 25 | def on_iteration_end(self, method, iteration, X, clusters, assignments): 26 | 27 | plot_clustering(X, assignments=assignments, centers=clusters, center_per_point=self.center_per_cluster, fig=self.fig, ax=self.ax) 28 | plot_path = os.path.join(self.tmp_folder, 'iteration-{}.png'.format(iteration)) 29 | plt.savefig(plot_path) 30 | self.files.append(plot_path) 31 | 32 | def on_epoch_end(self, method, X, clusters, assignments): 33 | if len(self.files) == 0: 34 | print('no images for gif') 35 | return 36 | 37 | X = method.tensor_to_numpy(X) 38 | 39 | plot_clustering(X, assignments=assignments, centers=clusters, fig=self.fig, ax=self.ax) 40 | plot_path = os.path.join(self.tmp_folder, 'iteration-final.png') 41 | plt.savefig(plot_path) 42 | 43 | for _ in range(self.repeat_last_result): 44 | self.files.append(plot_path) 45 | 46 | if self.plot_png: 47 | plt.savefig(self.filename) 48 | 49 | print('creating gif') 50 | with imageio.get_writer(self.filename_gif, mode='I', fps=self.fps) as writer: 51 | for filename in self.files: 52 | image = imageio.imread(filename) 53 | writer.append_data(image) 54 | 55 | print('cleaning up temporary images') 56 | for filename in set(self.files): 57 | os.remove(filename) 58 | -------------------------------------------------------------------------------- /algorithms/meanshift_tf_eager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math 4 | 5 | from algorithms.meanshift_base import MeanShiftBase 6 | 7 | 8 | class MeanShiftTensorflowEager(MeanShiftBase): 9 | 10 | def __init__(self, cuda=True, *args, **kwargs) -> None: 11 | 
super().__init__(*args, **kwargs) 12 | 13 | self.cuda = cuda 14 | self.bandwidth = tf.constant(self.bandwidth) 15 | self.pi = tf.constant(math.pi) 16 | 17 | def prepare(self, X): 18 | X = tf.convert_to_tensor(X) 19 | return X, tf.identity(X) 20 | 21 | def distance(self, a, b): 22 | # shape a: N x C 23 | # shape b: M x C 24 | 25 | d = tf.reduce_sum(tf.square(tf.expand_dims(a, 1) - tf.expand_dims(b, 0)), axis=-1) 26 | return d 27 | 28 | def kernel(self, distances): 29 | return tf.exp(-0.5 * ((distances / self.bandwidth ** 2))) 30 | 31 | def _main_loop(self, X, clusters): 32 | 33 | iteration = 0 34 | while True: 35 | iteration += 1 36 | d = self.distance(clusters, X) 37 | w = self.kernel(d) 38 | 39 | new_centers = w[:, :, tf.newaxis] * X[tf.newaxis, :, :] 40 | w_sum = tf.reduce_sum(w, axis=1) 41 | new_centers = tf.reduce_sum(new_centers, axis=1) / w_sum[:, None] 42 | 43 | diff = tf.reduce_sum(tf.square(new_centers - clusters)) 44 | 45 | if self.verbose: 46 | print('Iteration {}: {:.5f} difference'.format(iteration, float(diff))) 47 | 48 | if diff < self.early_stop_threshold: 49 | break 50 | clusters = new_centers 51 | 52 | clusters, assignments = self._group_clusters(clusters) 53 | return clusters, assignments 54 | 55 | def _group_clusters(self, points): 56 | cluster_ids = [] 57 | cluster_centers = [] 58 | 59 | for point in points: 60 | add = True 61 | for cluster_index, cluster in enumerate(cluster_centers): 62 | dist = tf.reduce_sum(tf.square(point - cluster), axis=-1) 63 | if dist < self.cluster_threshold: 64 | cluster_ids.append(cluster_index) 65 | add = False 66 | break 67 | 68 | if add: 69 | cluster_ids.append(len(cluster_centers)) 70 | cluster_centers.append(point) 71 | 72 | cluster_centers = tf.stack(cluster_centers, axis=0) 73 | cluster_ids = tf.convert_to_tensor(cluster_ids) 74 | return cluster_centers, cluster_ids 75 | 76 | def tensor_to_numpy(self, t): 77 | return np.array(t) 78 | 
-------------------------------------------------------------------------------- /algorithms/meanshift_tf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math 4 | 5 | from algorithms.meanshift_base import MeanShiftBase 6 | 7 | 8 | class MeanShiftTensorflow(MeanShiftBase): 9 | 10 | def __init__(self, cuda=True, *args, **kwargs) -> None: 11 | super().__init__(*args, **kwargs) 12 | 13 | self.cuda = cuda 14 | self.bandwidth = tf.constant(self.bandwidth) 15 | self.pi = tf.constant(math.pi) 16 | 17 | def prepare(self, X): 18 | X = tf.convert_to_tensor(X) 19 | return X, tf.identity(X) 20 | 21 | @tf.function 22 | def distance(self, a, b): 23 | # shape a: N x C 24 | # shape b: M x C 25 | 26 | d = tf.reduce_sum(tf.square(a[:, None, :] - b[None, :, :]), axis=-1) 27 | return d 28 | 29 | @tf.function 30 | def kernel(self, distances): 31 | return tf.exp(-0.5 * ((distances / self.bandwidth ** 2))) 32 | 33 | def _main_loop(self, X, clusters): 34 | 35 | @tf.function 36 | def step(clusters, diff): 37 | d = self.distance(clusters, X) 38 | w = self.kernel(d) 39 | 40 | new_centers = w[:, :, None] * X[None, :, :] 41 | w_sum = tf.reduce_sum(w, axis=1) 42 | new_centers = tf.reduce_sum(new_centers, axis=1) / w_sum[:, None] 43 | 44 | diff = tf.reduce_sum(tf.square(new_centers - clusters)) 45 | 46 | return new_centers, diff 47 | 48 | @tf.function 49 | def cond(clusters, diff): 50 | return diff > self.early_stop_threshold 51 | 52 | clusters, diff = tf.while_loop( 53 | cond, 54 | step, 55 | (clusters, 1000) 56 | ) 57 | 58 | clusters, assignments = self._group_clusters(clusters) 59 | return clusters, assignments 60 | 61 | def _group_clusters(self, points): 62 | cluster_ids = [] 63 | cluster_centers = [] 64 | 65 | for point in points: 66 | add = True 67 | for cluster_index, cluster in enumerate(cluster_centers): 68 | dist = tf.reduce_sum(tf.square(point - cluster), axis=-1) 69 | if dist < 
self.cluster_threshold: 70 | cluster_ids.append(cluster_index) 71 | add = False 72 | break 73 | 74 | if add: 75 | cluster_ids.append(len(cluster_centers)) 76 | cluster_centers.append(point) 77 | 78 | cluster_centers = tf.stack(cluster_centers, axis=0) 79 | cluster_ids = tf.convert_to_tensor(cluster_ids) 80 | return cluster_centers, cluster_ids 81 | 82 | def tensor_to_numpy(self, t): 83 | return np.array(t) 84 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | 2 | def plot_clustering(points, assignments=None, centers=None, fig=None, ax=None, labels=None, 3 | alpha=1., center_size=1., center_marker='o', point_size=1., point_marker='x', 4 | palette=None, pfx=True, equal_axis_scale=True, center_per_point=False): 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from matplotlib import rcParams 8 | import seaborn as sns 9 | 10 | sns.set_style("white") 11 | 12 | if pfx: 13 | import matplotlib.patheffects as path_effects 14 | path_effects = [path_effects.withStroke(linewidth=2, foreground='black')] 15 | else: 16 | path_effects = None 17 | 18 | if assignments is None: 19 | assignments = np.zeros(len(points), dtype=int) 20 | 21 | assert len(points) == len(assignments) 22 | 23 | if len(points.shape) != 2 or points.shape[-1] != 2: 24 | raise ValueError("Invalid points shape: {}".format(points.shape)) 25 | 26 | def plot_wrap(fn, pts, *args, **kwargs): 27 | # Ensure always a 2D array even if just one point 28 | if len(pts.shape) == 1: 29 | pts = np.array([pts]) 30 | 31 | return fn(pts[:, 0], pts[:, 1], *args, **kwargs) 32 | 33 | if centers is not None: 34 | n_centers = len(centers) 35 | assert np.max(assignments) < n_centers 36 | else: 37 | n_centers = np.max(assignments) + 1 38 | 39 | if palette is None: 40 | colors = plt.get_cmap('tab10').colors 41 | palette = [colors[i % 10] for i in range(n_centers)] 42 | 43 | unique = 
np.unique(assignments) 44 | 45 | if fig is None: 46 | fig = plt.figure(figsize=(6, 6)) 47 | 48 | 49 | 50 | if ax is None: 51 | ax = fig.add_subplot(111) 52 | 53 | ax.cla() 54 | 55 | for u in unique: 56 | label = labels[u] if labels is not None else None 57 | 58 | p_color = palette[u] 59 | p_points = points[assignments == u] 60 | plot_wrap(ax.scatter, p_points, marker=point_marker, s=point_size * rcParams['lines.markersize'] ** 2, 61 | color=p_color, label=label, alpha=alpha, path_effects=path_effects) 62 | 63 | if centers is not None: 64 | if center_per_point: 65 | p_cluster = centers[assignments == u, :] 66 | else: 67 | p_cluster = centers[u] 68 | 69 | plot_wrap(ax.scatter, p_cluster, marker=center_marker, color=palette[u], 70 | s=center_size * rcParams['lines.markersize'] ** 2, alpha=alpha, path_effects=path_effects) 71 | 72 | if equal_axis_scale: 73 | # Ensure X and Y axis have equal scale in visualization 74 | plt.gca().set_aspect('equal', adjustable='box') 75 | 76 | ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 77 | 78 | 79 | plt.tight_layout() 80 | return ax -------------------------------------------------------------------------------- /algorithms/builder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | def build_algorithm(cfg, callback=None, verbose=False): 4 | algorithm_name = cfg.algorithm_name 5 | framework_name = cfg.framework 6 | 7 | if algorithm_name == 'kmeans': 8 | 9 | if framework_name == 'numpy': 10 | from algorithms.kmeans_np import KMeansNumpy 11 | clustering_class = KMeansNumpy 12 | elif framework_name == 'pytorch': 13 | from algorithms.kmeans_pt import KMeansPytorch 14 | clustering_class = partial(KMeansPytorch, cuda=cfg.cuda) 15 | elif framework_name == 'jax': 16 | from algorithms.kmeans_jax import KMeansJax 17 | clustering_class = KMeansJax 18 | elif framework_name == 'tensorflow': 19 | from algorithms.kmeans_tf import KMeansTensorflow 20 | 
clustering_class = KMeansTensorflow 21 | elif framework_name == 'tensorflow_eager': 22 | from algorithms.kmeans_tf_eager import KMeansTensorflowEager 23 | clustering_class = KMeansTensorflowEager 24 | else: 25 | raise ValueError('Unknown framework or not implemented: {}'.format(framework_name)) 26 | 27 | n_clusters = cfg.n_clusters if cfg.n_clusters is not None else cfg.default_n_clusters 28 | clustering = clustering_class(n_clusters=n_clusters, max_iter=cfg.max_iter, early_stop_threshold=cfg.early_stop_threshold, verbose=verbose, callback=callback, seed=cfg.seed) 29 | 30 | elif algorithm_name == 'meanshift': 31 | 32 | if framework_name == 'numpy': 33 | from algorithms.meanshift_np import MeanShiftNumpy 34 | clustering_class = MeanShiftNumpy 35 | elif framework_name == 'pytorch': 36 | from algorithms.meanshift_pt import MeanShiftPytorch 37 | clustering_class = partial(MeanShiftPytorch, cuda=cfg.cuda) 38 | elif framework_name == 'jax': 39 | from algorithms.meanshift_jax import MeanShiftJax 40 | clustering_class = MeanShiftJax 41 | elif framework_name == 'tensorflow': 42 | from algorithms.meanshift_tf import MeanShiftTensorflow 43 | clustering_class = MeanShiftTensorflow 44 | elif framework_name == 'tensorflow_eager': 45 | from algorithms.meanshift_tf_eager import MeanShiftTensorflowEager 46 | clustering_class = MeanShiftTensorflowEager 47 | else: 48 | raise ValueError('Unknown framework or not implemented: {}'.format(framework_name)) 49 | 50 | bandwidth = cfg.bandwidth if cfg.bandwidth is not None else cfg.default_bandwidth 51 | clustering = clustering_class(bandwidth=bandwidth, early_stop_threshold=cfg.early_stop_threshold, cluster_threshold=cfg.cluster_threshold, verbose=verbose, callback=callback) 52 | 53 | else: 54 | raise ValueError('Unknown algorithm: {}'.format(algorithm_name)) 55 | 56 | 57 | return clustering -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Clustering 2 | 3 | 4 | Implementation of unsupervised clustering algorithms from scratch in different machine learning frameworks. The goal is to demonstrate the similarities and differences of the frameworks. 5 | If you have an idea to improve an implementation (e.g., a more elegant or faster solution) or would like to implement a different algorithm/framework, please feel free to contribute. 6 | 7 | Clustering algorithms 8 | - K-Means 9 | - Mean shift 10 | 11 | Machine learning frameworks 12 | - [NumPy](https://numpy.org) 13 | - [PyTorch](https://pytorch.org) 14 | - [TensorFlow (Eager and Graph Mode)](https://www.tensorflow.org) 15 | - [JAX](https://jax.readthedocs.io/) 16 | 17 | 18 | ## Algorithms 19 | 20 | | Algorithm | Framework | | 21 | | :--------- | :------ | :------ | 22 | | K-Means | NumPy | [kmeans_np.py](algorithms/kmeans_np.py) | 23 | | | PyTorch | [kmeans_pt.py](algorithms/kmeans_pt.py) | 24 | | | TensorFlow 2 (Eager) | [kmeans_tf_eager.py](algorithms/kmeans_tf_eager.py) | 25 | | | TensorFlow 2 (Graph) | [kmeans_tf.py](algorithms/kmeans_tf.py) | 26 | | | JAX | [kmeans_jax.py](algorithms/kmeans_jax.py) | 27 | | Mean shift | NumPy | [meanshift_np.py](algorithms/meanshift_np.py) | 28 | | | PyTorch | [meanshift_pt.py](algorithms/meanshift_pt.py) | 29 | | | JAX | [meanshift_jax.py](algorithms/meanshift_jax.py) | 30 | | | TensorFlow 2 (Eager) | [meanshift_tf_eager.py](algorithms/meanshift_tf_eager.py) | 31 | | | TensorFlow 2 (Graph) | [meanshift_tf.py](algorithms/meanshift_tf.py) | 32 | 33 | 34 | ![Mean shift on Aggregation](images/meanshift.gif) 35 | 36 | ## Usage 37 | 38 | Please follow the [installation guide](#installation). 
39 | 40 | You can simply run the following command to execute mean shift clustering on the `aggregation` dataset with `JAX` 41 | ``` 42 | python main.py 43 | ``` 44 | To select a different algorithm, dataset, or framework, pass the corresponding option on the command line as a `key=value` override. 45 | 46 | To do so, set 47 | - `algorithm` to `kmeans` or `meanshift` 48 | - `dataset` to `aggregation`, `jain`, `moons`, or `s4` 49 | - `framework` to `numpy`, `pytorch`, `jax`, `tensorflow_eager`, or `tensorflow` 50 | 51 | For example 52 | ``` 53 | python main.py algorithm=kmeans dataset=moons framework=pytorch 54 | ``` 55 | For all options, please see `configs/base.yaml`. 56 | 57 | For timing, set `time=true` 58 | ``` 59 | python main.py time=true 60 | ``` 61 | To plot the result, set `plot=true` 62 | ``` 63 | python main.py plot=true 64 | ``` 65 | 66 | To create a GIF of the iterations, set `plot_gif=true` 67 | ``` 68 | python main.py plot_gif=true 69 | ``` 70 | 71 | ## Installation 72 | 73 | Clone the repository 74 | ```bash 75 | git clone git@github.com:creinders/ClusteringAlgorithmsFromScratch.git 76 | cd ClusteringAlgorithmsFromScratch 77 | ``` 78 | 79 | Create a conda environment and install the dependencies 80 | ``` 81 | conda create -n clustering python=3.9 82 | conda activate clustering 83 | 84 | # Install PyTorch (follow https://pytorch.org/get-started) 85 | conda install pytorch -c pytorch 86 | 87 | # Install TensorFlow (follow https://www.tensorflow.org/install/pip) 88 | pip install tensorflow 89 | 90 | # Install JAX (follow https://github.com/google/jax#installation) 91 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 92 | 93 | pip install -r requirements.txt 94 | ``` 95 | 96 | If you want to use the datasets `aggregation`, `jain`, or `s4`, please download the data first: 97 | ``` 98 | ./download_datasets.sh 99 | ``` 100 | --------------------------------------------------------------------------------