├── requirements.txt ├── FMS-Euc-git ├── meanshift │ ├── __pycache__ │ │ ├── batch_seed.cpython-38.pyc │ │ └── mean_shift_gpu.cpython-38.pyc │ ├── batch_seed.py │ └── mean_shift_gpu.py ├── FMS-Euc.sln ├── main.py └── FMS-Euc.pyproj └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.5 2 | scikit-learn>=0.24.1 3 | torch>=1.8.1 4 | matplotlib 5 | -------------------------------------------------------------------------------- /FMS-Euc-git/meanshift/__pycache__/batch_seed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masqm/Faster-Mean-Shift-Euc/HEAD/FMS-Euc-git/meanshift/__pycache__/batch_seed.cpython-38.pyc -------------------------------------------------------------------------------- /FMS-Euc-git/meanshift/__pycache__/mean_shift_gpu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masqm/Faster-Mean-Shift-Euc/HEAD/FMS-Euc-git/meanshift/__pycache__/mean_shift_gpu.cpython-38.pyc -------------------------------------------------------------------------------- /FMS-Euc-git/FMS-Euc.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 16 4 | VisualStudioVersion = 16.0.31205.134 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "FMS-Euc", "FMS-Euc.pyproj", "{4F293353-3B33-42A6-9A0C-C959F8D39DD7}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {4F293353-3B33-42A6-9A0C-C959F8D39DD7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 |
{4F293353-3B33-42A6-9A0C-C959F8D39DD7}.Release|Any CPU.ActiveCfg = Release|Any CPU 16 | EndGlobalSection 17 | GlobalSection(SolutionProperties) = preSolution 18 | HideSolutionNode = FALSE 19 | EndGlobalSection 20 | GlobalSection(ExtensibilityGlobals) = postSolution 21 | SolutionGuid = {C00CD930-BF0E-44D7-A474-D9EB35BCC5DA} 22 | EndGlobalSection 23 | EndGlobal 24 | -------------------------------------------------------------------------------- /FMS-Euc-git/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | #from sklearn.cluster import MeanShift 6 | from sklearn import datasets, cluster 7 | from sklearn.preprocessing import StandardScaler 8 | 9 | from meanshift.mean_shift_gpu import MeanShiftEuc 10 | 11 | def main(): 12 | # Generate a blob dataset. 13 | n_samples = 100000 14 | blobs = datasets.make_blobs(n_samples=n_samples, random_state=9) 15 | 16 | # Normalize dataset for easier parameter selection 17 | X, y = blobs 18 | X = StandardScaler().fit_transform(X) 19 | 20 | # Estimate bandwidth for mean shift(Select 1000 points) 21 | bandwidth = cluster.estimate_bandwidth(X[0:999]) 22 | bandwidth_gpu = 2*bandwidth/(X.max()-X.min()) 23 | 24 | # Obtain results 25 | ms = MeanShiftEuc(bandwidth=bandwidth_gpu, cluster_all=True, GPU=True) 26 | ms.fit(X) 27 | labels = ms.labels_ 28 | 29 | return 0 30 | 31 | if __name__ == "__main__": 32 | # execute only if run as a script 33 | main() 34 | -------------------------------------------------------------------------------- /FMS-Euc-git/FMS-Euc.pyproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Debug 5 | 2.0 6 | {4f293353-3b33-42a6-9a0c-c959f8d39dd7} 7 | 8 | main.py 9 | meanshift\ 10 | . 11 | . 
12 | {888888a0-9f3d-457c-b088-3a5042f75d52} 13 | Standard Python launcher 14 | Global|ContinuumAnalytics|Anaconda38-64 15 | 16 | 17 | 18 | 19 | 10.0 20 | 21 | 22 | 23 | Code 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /FMS-Euc-git/meanshift/batch_seed.py: -------------------------------------------------------------------------------- 1 | # Author Mengyang Zhao 2 | 3 | import math 4 | import operator 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | from torch import exp, sqrt 11 | 12 | def euc_batch(a, b): 13 | result = sqrt(((b[None,:] - a[:,None]) ** 2).sum(2)) 14 | #pdist = torch.nn.PairwiseDistance(p=2) 15 | #result = pdist(a, b) 16 | return result 17 | 18 | #num = a@b.T 19 | #denom = torch.norm(a, dim=1).reshape(-1, 1) * torch.norm(b, dim=1) 20 | #return num / denom 21 | 22 | def get_weight(sim, bandwidth): 23 | 24 | thr = 1-bandwidth 25 | #max = torch.tensor(1.0e+10).double().cuda() 26 | max = torch.tensor(1.0).double().cuda() 27 | min = torch.tensor(0.0).double().cuda() 28 | #dis=torch.where(sim>thr, 1-sim, max) 29 | dis=torch.where(sim>thr, max, min) 30 | 31 | return dis 32 | 33 | def gaussian(dist, bandwidth): 34 | return exp(-0.5 * ((dist / bandwidth))**2) / (bandwidth * math.sqrt(2 * math.pi)) 35 | 36 | def meanshift_torch(data, seed , bandwidth, max_iter=300): 37 | 38 | stop_thresh = 1e-3 * bandwidth 39 | iter=0 40 | 41 | X = torch.from_numpy(np.copy(data)).double().cuda() 42 | S = torch.from_numpy(np.copy(seed)).double().cuda() 43 | B = torch.tensor(bandwidth).double().cuda() 44 | 45 | while True: 46 | #cosine = cos_batch(S, X) 47 | 48 | weight = gaussian(euc_batch(S, X),B) 49 | 50 | #torch.where(distances>(1-bandwidth)) 51 | #weight = gaussian(distances, B) 52 | num = (weight[:, :, None] * X).sum(dim=1) 53 | S_old = S 54 | S = num / weight.sum(1)[:, None] 55 | #cosine2 = torch.norm(S - S_old, dim=1).mean() 56 | 
iter+=1 57 | 58 | if (torch.norm(S - S_old, dim=1).mean() < stop_thresh or iter == max_iter): 59 | break 60 | 61 | p_num=[] 62 | for line in weight: 63 | p_num.append(line[line==1].size()[0]) 64 | 65 | my_mean = S.cpu().numpy() 66 | 67 | return my_mean, p_num 68 | 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Faster-Mean-Shift-Euc 2 | Faster Mean-shift algorithm with Euclidean Distance Metrics. The algorithm is based on GPU acceleration, which can achieve satisfactory speedup with optimized GPU memory consumption. Here is a brief introduction to how to run it. And the details of our algorithm, please refer to our [paper1](https://arxiv.org/abs/2112.13891), [paper2](https://doi.org/10.1016/j.media.2021.102048). 3 | 4 | The Cosine Metrics version is provided in another repository [Faster-Mean-Shift](https://github.com/masqm/Faster-Mean-Shift). 5 | 6 | 7 | ## Environment 8 | Win10 9 | 10 | VS2019 11 | 12 | Anacoda 2020.11 13 | 14 | The packages requirement please see [requirements.txt](https://github.com/masqm/Faster-Mean-Shift-Euc/blob/main/requirements.txt) 15 | 16 | Please make sure to use a compatible pytorch-gpu version. 17 | 18 | ## Example 19 | Using our algorithm is similar to calling the meanshift in sklearn. 20 | An example of how to run the algorithm is given in [main.py](https://github.com/masqm/Faster-Mean-Shift-Euc/blob/main/FMS-Euc-git/main.py): 21 | 22 | # Generate a blob dataset. 
23 | n_samples = 100000 24 | blobs = datasets.make_blobs(n_samples=n_samples, random_state=9) 25 | 26 | # Normalize dataset for easier parameter selection 27 | X, y = blobs 28 | X = StandardScaler().fit_transform(X) 29 | 30 | # Estimate bandwidth for mean shift(Select 1000 points) 31 | bandwidth = cluster.estimate_bandwidth(X[0:999]) 32 | bandwidth_gpu = 2*bandwidth/(X.max()-X.min()) 33 | 34 | # Obtain results 35 | ms = MeanShiftEuc(bandwidth=bandwidth_gpu, cluster_all=True, GPU=True) 36 | ms.fit(X) 37 | labels = ms.labels_ 38 | 39 | Here is a helpful Q&A on image segmentation and bandwidth selection: https://github.com/masqm/Faster-Mean-Shift-Euc/issues/1 40 | 41 | If you run into a problem or find a bug, please contact me at Mengyang.Zhao.TH@dartmouth.edu. If you use this code in your research, please cite our [paper1](https://arxiv.org/abs/2112.13891), [paper2](https://doi.org/10.1016/j.media.2021.102048). Thanks! 42 | -------------------------------------------------------------------------------- /FMS-Euc-git/meanshift/mean_shift_gpu.py: -------------------------------------------------------------------------------- 1 | """Mean shift clustering algorithm. 2 | 3 | Mean shift clustering aims to discover *blobs* in a smooth density of 4 | samples. It is a centroid-based algorithm, which works by updating candidates 5 | for centroids to be the Gaussian-weighted mean of the points around them. These 6 | candidates are then filtered in a post-processing stage to eliminate 7 | near-duplicates to form the final set of centroids. 8 | 9 | Seeding is performed by sampling data points at random; the seed count is doubled while it is below L times the number of clusters found.
10 | """ 11 | 12 | # Author Mengyang Zhao 13 | 14 | # Based on: Conrad Lee 15 | # Alexandre Gramfort 16 | # Gael Varoquaux 17 | # Martino Sorbaro 18 | 19 | import numpy as np 20 | import warnings 21 | import math 22 | 23 | from collections import defaultdict 24 | #from sklearn.externals import six 25 | from sklearn.utils.validation import check_is_fitted 26 | from sklearn.utils import check_random_state, gen_batches, check_array 27 | from sklearn.base import BaseEstimator, ClusterMixin 28 | from sklearn.neighbors import NearestNeighbors 29 | from sklearn.metrics.pairwise import pairwise_distances_argmin 30 | from joblib import Parallel 31 | from joblib import delayed 32 | 33 | from meanshift.batch_seed import meanshift_torch 34 | from random import shuffle 35 | 36 | 37 | #seeds number intital 38 | SEED_NUM = 128 39 | L=8 40 | H=32 41 | 42 | 43 | def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0, n_jobs=None): 44 | """Estimate the bandwidth to use with the mean-shift algorithm. 45 | 46 | That this function takes time at least quadratic in n_samples. For large 47 | datasets, it's wise to set that parameter to a small value. 48 | 49 | Parameters 50 | ---------- 51 | X : array-like, shape=[n_samples, n_features] 52 | Input points. 53 | 54 | quantile : float, default 0.3 55 | should be between [0, 1] 56 | 0.5 means that the median of all pairwise distances is used. 57 | 58 | n_samples : int, optional 59 | The number of samples to use. If not given, all samples are used. 60 | 61 | random_state : int, RandomState instance or None (default) 62 | The generator used to randomly select the samples from input points 63 | for bandwidth estimation. Use an int to make the randomness 64 | deterministic. 65 | See :term:`Glossary `. 66 | 67 | n_jobs : int or None, optional (default=None) 68 | The number of parallel jobs to run for neighbors search. 69 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 70 | ``-1`` means using all processors. 
See :term:`Glossary ` 71 | for more details. 72 | 73 | Returns 74 | ------- 75 | bandwidth : float 76 | The bandwidth parameter. 77 | """ 78 | X = check_array(X) 79 | 80 | random_state = check_random_state(random_state) 81 | if n_samples is not None: 82 | idx = random_state.permutation(X.shape[0])[:n_samples] 83 | X = X[idx] 84 | n_neighbors = int(X.shape[0] * quantile) 85 | if n_neighbors < 1: # cannot fit NearestNeighbors with n_neighbors = 0 86 | n_neighbors = 1 87 | nbrs = NearestNeighbors(n_neighbors=n_neighbors, 88 | n_jobs=n_jobs) 89 | nbrs.fit(X) 90 | 91 | bandwidth = 0. 92 | for batch in gen_batches(len(X), 500): 93 | d, _ = nbrs.kneighbors(X[batch, :], return_distance=True) 94 | bandwidth += np.max(d, axis=1).sum() 95 | 96 | return bandwidth / X.shape[0] 97 | 98 | 99 | def gpu_seed_generator(codes): 100 | 101 | seed_indizes = list(range(codes.shape[0])) 102 | shuffle(seed_indizes) 103 | seed_indizes = seed_indizes[:SEED_NUM] 104 | seeds = codes[seed_indizes] 105 | 106 | return seeds 107 | 108 | def gpu_seed_adjust(codes): 109 | global SEED_NUM 110 | SEED_NUM *= 2 111 | 112 | return gpu_seed_generator(codes) 113 | 114 | def get_N(P,r,I): 115 | 116 | #There is no foreground instances 117 | if r<0.1: 118 | return 32 #Allocated some seeds at least 119 | 120 | lnp = math.log(P,math.e) 121 | num=math.log(1-math.e**(lnp/I),math.e) 122 | den = math.log(1-r/I,math.e) 123 | result = num/den 124 | 125 | if result<32: 126 | result =32 #Allocated some seeds at least 127 | elif result>256: 128 | result =256 #Our GPU memory's max limitation, you can higher it. 129 | 130 | return int(result) 131 | 132 | 133 | def mean_shift_euc(X, bandwidth=None, seeds=None, 134 | cluster_all=True, GPU=True): 135 | """Perform mean shift clustering of data using a flat kernel. 136 | 137 | Read more in the :ref:`User Guide `. 138 | 139 | Parameters 140 | ---------- 141 | 142 | X : array-like, shape=[n_samples, n_features] 143 | Input data. 
144 | 145 | bandwidth : float, optional 146 | Kernel bandwidth. 147 | 148 | If bandwidth is not given, it is determined using a heuristic based on 149 | the median of all pairwise distances. This will take quadratic time in 150 | the number of samples. The sklearn.cluster.estimate_bandwidth function 151 | can be used to do this more efficiently. 152 | 153 | seeds : array-like, shape=[n_seeds, n_features] or None 154 | Point used as initial kernel locations. 155 | 156 | cluster_all : boolean, default True 157 | If true, then all points are clustered, even those orphans that are 158 | not within any kernel. Orphans are assigned to the nearest kernel. 159 | If false, then orphans are given cluster label -1. 160 | 161 | GPU : bool, default True 162 | Using GPU-based faster mean-shift 163 | 164 | 165 | Returns 166 | ------- 167 | 168 | cluster_centers : array, shape=[n_clusters, n_features] 169 | Coordinates of cluster centers. 170 | 171 | labels : array, shape=[n_samples] 172 | Cluster labels for each point. 
173 | 174 | 175 | """ 176 | 177 | if bandwidth is None: 178 | bandwidth = estimate_bandwidth(X) 179 | elif bandwidth <= 0: 180 | raise ValueError("bandwidth needs to be greater than zero or None,\ 181 | got %f" % bandwidth) 182 | if seeds is None: 183 | if GPU == True: 184 | seeds = gpu_seed_generator(X) 185 | 186 | 187 | #adjusted=False 188 | n_samples, n_features = X.shape 189 | center_intensity_dict = {} 190 | nbrs = NearestNeighbors(radius=bandwidth, metric='cosine').fit(X) 191 | #NearestNeighbors(radius=bandwidth, n_jobs=n_jobs, metric='cosine').radius_neighbors() 192 | 193 | global SEED_NUM 194 | if GPU == True: 195 | #GPU ver 196 | while True: 197 | labels, number = meanshift_torch(X, seeds, bandwidth)#gpu calculation 198 | for i in range(len(number)): 199 | if number[i] is not None: 200 | center_intensity_dict[tuple(labels[i])] = number[i]#find out cluster 201 | 202 | if not center_intensity_dict: 203 | # nothing near seeds 204 | raise ValueError("No point was within bandwidth=%f of any seed." 205 | " Try a different seeding strategy \ 206 | or increase the bandwidth." 207 | % bandwidth) 208 | 209 | # POST PROCESSING: remove near duplicate points 210 | # If the distance between two kernels is less than the bandwidth, 211 | # then we have to remove one because it is a duplicate. Remove the 212 | # one with fewer points. 
213 | 214 | sorted_by_intensity = sorted(center_intensity_dict.items(), 215 | key=lambda tup: (tup[1], tup[0]), 216 | reverse=True) 217 | sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) 218 | unique = np.ones(len(sorted_centers), dtype=np.bool) 219 | nbrs = NearestNeighbors(radius=bandwidth, metric='cosine').fit(sorted_centers) 220 | for i, center in enumerate(sorted_centers): 221 | if unique[i]: 222 | neighbor_idxs = nbrs.radius_neighbors([center], 223 | return_distance=False)[0] 224 | unique[neighbor_idxs] = 0 225 | unique[i] = 1 # leave the current point as unique 226 | cluster_centers = sorted_centers[unique] 227 | 228 | 229 | # assign labels 230 | nbrs = NearestNeighbors(n_neighbors=1).fit(cluster_centers) 231 | labels = np.zeros(n_samples, dtype=np.int) 232 | distances, idxs = nbrs.kneighbors(X) 233 | if cluster_all: 234 | labels = idxs.flatten() 235 | else: 236 | labels.fill(-1) 237 | bool_selector = distances.flatten() <= bandwidth 238 | labels[bool_selector] = idxs.flatten()[bool_selector] 239 | 240 | #Test 241 | #break 242 | 243 | bg_num = np.sum(labels==0) 244 | r = 1-bg_num/labels.size 245 | #seed number adjust 246 | dict_len = len(cluster_centers)#cluster number 247 | 248 | M = dict_len 249 | 250 | 251 | if L*M <= SEED_NUM: #safety area 252 | #SEED_NUM -= 200#test 253 | #if H*M <= SEED_NUM: 254 | # SEED_NUM -= M #seeds are too much, adjsut 255 | 256 | break 257 | else: 258 | seeds = gpu_seed_adjust(X)#seeds are too few, adjsut 259 | 260 | return cluster_centers, labels 261 | 262 | 263 | 264 | class MeanShiftEuc(BaseEstimator, ClusterMixin): 265 | """Mean shift clustering using a flat kernel. 266 | 267 | Mean shift clustering aims to discover "blobs" in a smooth density of 268 | samples. It is a centroid-based algorithm, which works by updating 269 | candidates for centroids to be the mean of the points within a given 270 | region. 
These candidates are then filtered in a post-processing stage to 271 | eliminate near-duplicates to form the final set of centroids. 272 | 273 | Seeding is performed using a binning technique for scalability. 274 | 275 | Read more in the :ref:`User Guide `. 276 | 277 | Parameters 278 | ---------- 279 | bandwidth : float, optional 280 | Bandwidth used in the RBF kernel. 281 | 282 | If not given, the bandwidth is estimated using 283 | sklearn.cluster.estimate_bandwidth; see the documentation for that 284 | function for hints on scalability (see also the Notes, below). 285 | 286 | seeds : array, shape=[n_samples, n_features], optional 287 | Seeds used to initialize kernels. If not set, 288 | the seeds are calculated by clustering.get_bin_seeds 289 | with bandwidth as the grid size and default values for 290 | other parameters. 291 | 292 | cluster_all : boolean, default True 293 | If true, then all points are clustered, even those orphans that are 294 | not within any kernel. Orphans are assigned to the nearest kernel. 295 | If false, then orphans are given cluster label -1. 296 | 297 | GPU : bool, default True 298 | Using GPU-based faster mean-shift 299 | 300 | 301 | Attributes 302 | ---------- 303 | cluster_centers_ : array, [n_clusters, n_features] 304 | Coordinates of cluster centers. 305 | 306 | labels_ : 307 | Labels of each point. 308 | 309 | Examples 310 | -------- 311 | >>> from sklearn.cluster import MeanShift 312 | >>> import numpy as np 313 | >>> X = np.array([[1, 1], [2, 1], [1, 0], 314 | ... [4, 7], [3, 5], [3, 6]]) 315 | >>> clustering = MeanShift(bandwidth=2).fit(X) 316 | >>> clustering.labels_ 317 | array([1, 1, 1, 0, 0, 0]) 318 | >>> clustering.predict([[0, 0], [5, 5]]) 319 | array([1, 0]) 320 | >>> clustering # doctest: +NORMALIZE_WHITESPACE 321 | MeanShift(bandwidth=2, cluster_all=True, seeds=None) 322 | 323 | References 324 | ---------- 325 | 326 | Dorin Comaniciu and Peter Meer, "Mean Shift: A robust approach toward 327 | feature space analysis". 
IEEE Transactions on Pattern Analysis and 328 | Machine Intelligence. 2002. pp. 603-619. 329 | 330 | """ 331 | def __init__(self, bandwidth=None, seeds=None, cluster_all=True, GPU=True): 332 | self.bandwidth = bandwidth 333 | self.seeds = seeds 334 | self.cluster_all = cluster_all 335 | self.GPU = GPU 336 | 337 | def fit(self, X, y=None): 338 | """Perform clustering. 339 | 340 | Parameters 341 | ----------- 342 | X : array-like, shape=[n_samples, n_features] 343 | Samples to cluster. 344 | 345 | y : Ignored 346 | 347 | """ 348 | X = check_array(X) 349 | self.cluster_centers_, self.labels_ = \ 350 | mean_shift_euc(X, bandwidth=self.bandwidth, seeds=self.seeds, 351 | cluster_all=self.cluster_all, GPU=self.GPU) 352 | return self 353 | 354 | def predict(self, X): 355 | """Predict the closest cluster each sample in X belongs to. 356 | 357 | Parameters 358 | ---------- 359 | X : {array-like, sparse matrix}, shape=[n_samples, n_features] 360 | New data to predict. 361 | 362 | Returns 363 | ------- 364 | labels : array, shape [n_samples,] 365 | Index of the cluster each sample belongs to. 366 | """ 367 | check_is_fitted(self, "cluster_centers_") 368 | 369 | return pairwise_distances_argmin(X, self.cluster_centers_) 370 | --------------------------------------------------------------------------------