├── requirements.txt
├── FMS-Euc-git
├── meanshift
│ ├── __pycache__
│ │ ├── batch_seed.cpython-38.pyc
│ │ └── mean_shift_gpu.cpython-38.pyc
│ ├── batch_seed.py
│ └── mean_shift_gpu.py
├── FMS-Euc.sln
├── main.py
└── FMS-Euc.pyproj
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.19.5
2 | scikit-learn>=0.24.1
3 | torch>=1.8.1  # install a CUDA build (e.g. 1.8.1+cu111) from the PyTorch wheel index
4 | matplotlib
5 |
--------------------------------------------------------------------------------
/FMS-Euc-git/meanshift/__pycache__/batch_seed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/masqm/Faster-Mean-Shift-Euc/HEAD/FMS-Euc-git/meanshift/__pycache__/batch_seed.cpython-38.pyc
--------------------------------------------------------------------------------
/FMS-Euc-git/meanshift/__pycache__/mean_shift_gpu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/masqm/Faster-Mean-Shift-Euc/HEAD/FMS-Euc-git/meanshift/__pycache__/mean_shift_gpu.cpython-38.pyc
--------------------------------------------------------------------------------
/FMS-Euc-git/FMS-Euc.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 16
4 | VisualStudioVersion = 16.0.31205.134
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "FMS-Euc", "FMS-Euc.pyproj", "{4F293353-3B33-42A6-9A0C-C959F8D39DD7}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|Any CPU = Debug|Any CPU
11 | Release|Any CPU = Release|Any CPU
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {4F293353-3B33-42A6-9A0C-C959F8D39DD7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15 | {4F293353-3B33-42A6-9A0C-C959F8D39DD7}.Release|Any CPU.ActiveCfg = Release|Any CPU
16 | EndGlobalSection
17 | GlobalSection(SolutionProperties) = preSolution
18 | HideSolutionNode = FALSE
19 | EndGlobalSection
20 | GlobalSection(ExtensibilityGlobals) = postSolution
21 | SolutionGuid = {C00CD930-BF0E-44D7-A474-D9EB35BCC5DA}
22 | EndGlobalSection
23 | EndGlobal
24 |
--------------------------------------------------------------------------------
/FMS-Euc-git/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | #from sklearn.cluster import MeanShift
6 | from sklearn import datasets, cluster
7 | from sklearn.preprocessing import StandardScaler
8 |
9 | from meanshift.mean_shift_gpu import MeanShiftEuc
10 |
def main(n_samples=100000):
    """Cluster a synthetic blob dataset with the GPU mean-shift implementation.

    Parameters
    ----------
    n_samples : int, default 100000
        Number of points generated by ``sklearn.datasets.make_blobs``.

    Returns
    -------
    int
        0 on success (used as the process exit status).
    """
    # Generate a blob dataset (the ground-truth labels are not needed).
    X, _ = datasets.make_blobs(n_samples=n_samples, random_state=9)

    # Normalize the dataset for easier parameter selection.
    X = StandardScaler().fit_transform(X)

    # Estimate the bandwidth from the first 1000 points only: the estimate
    # is at least quadratic in the number of points it looks at.
    bandwidth = cluster.estimate_bandwidth(X[:1000])
    # Scale the bandwidth relative to the data range, as in the README usage.
    bandwidth_gpu = 2 * bandwidth / (X.max() - X.min())

    # Fit, then report the number of clusters and the wall-clock time.
    start = time.perf_counter()
    ms = MeanShiftEuc(bandwidth=bandwidth_gpu, cluster_all=True, GPU=True)
    ms.fit(X)
    elapsed = time.perf_counter() - start

    labels = ms.labels_
    n_clusters = len(ms.cluster_centers_)
    print(f"Found {n_clusters} clusters for {labels.size} points in {elapsed:.2f} s")

    return 0


if __name__ == "__main__":
    # Execute only when run as a script; propagate main()'s status code.
    raise SystemExit(main())
34 |
--------------------------------------------------------------------------------
/FMS-Euc-git/FMS-Euc.pyproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Debug
5 | 2.0
6 | {4f293353-3b33-42a6-9a0c-c959f8d39dd7}
7 |
8 | main.py
9 | meanshift\
10 | .
11 | .
12 | {888888a0-9f3d-457c-b088-3a5042f75d52}
13 | Standard Python launcher
14 | Global|ContinuumAnalytics|Anaconda38-64
15 |
16 |
17 |
18 |
19 | 10.0
20 |
21 |
22 |
23 | Code
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/FMS-Euc-git/meanshift/batch_seed.py:
--------------------------------------------------------------------------------
1 | # Author Mengyang Zhao
2 |
3 | import math
4 | import operator
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | import torch
10 | from torch import exp, sqrt
11 |
def euc_batch(a, b):
    """Return the pairwise Euclidean distances between rows of ``a`` and ``b``.

    ``a`` gets a new axis at position 1 and ``b`` a new axis at position 0, so
    for 2-D inputs of shapes (m, d) and (n, d) the result has shape (m, n)
    and entry [i, j] equals ||a[i] - b[j]||.
    """
    diff = b.unsqueeze(0) - a.unsqueeze(1)
    squared = diff.pow(2).sum(dim=2)
    return sqrt(squared)
21 |
def get_weight(sim, bandwidth):
    """Return flat (0/1) kernel weights for a tensor of similarity values.

    Parameters
    ----------
    sim : torch.Tensor
        Similarity values; presumably cosine similarities in [-1, 1]
        (TODO confirm with callers).
    bandwidth : float
        Kernel width; entries with ``sim > 1 - bandwidth`` get weight 1.

    Returns
    -------
    torch.Tensor
        Float64 tensor with the same shape and device as ``sim``, holding
        1.0 inside the kernel and 0.0 outside.
    """
    thr = 1 - bandwidth
    # The comparison yields a bool mask on sim's own device; casting it to
    # double gives the 1.0/0.0 weights. This avoids creating CUDA scalar
    # tensors (which broke CPU-only use) and shadowing the builtins max/min.
    return (sim > thr).double()
32 |
def gaussian(dist, bandwidth):
    """Evaluate the normal density with standard deviation ``bandwidth`` at ``dist``.

    Computes exp(-(dist / bandwidth)**2 / 2) / (bandwidth * sqrt(2*pi))
    element-wise; used as the mean-shift kernel weight.
    """
    scaled = dist / bandwidth
    normalizer = bandwidth * math.sqrt(2 * math.pi)
    return exp(-0.5 * scaled ** 2) / normalizer
35 |
def meanshift_torch(data, seed, bandwidth, max_iter=300, device=None):
    """Run Gaussian-kernel mean shift from a batch of seeds simultaneously.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
        Points defining the density.
    seed : array-like, shape (n_seeds, n_features)
        Initial kernel locations; all seeds are shifted together.
    bandwidth : float
        Standard deviation of the Gaussian kernel. It is also the Euclidean
        radius used to count the points supporting each converged center.
    max_iter : int, default 300
        Maximum number of shift iterations (at least one is always run).
    device : str or torch.device, optional
        Where to compute. Defaults to CUDA when available, otherwise CPU.

    Returns
    -------
    my_mean : numpy.ndarray, shape (n_seeds, n_features)
        Converged seed positions.
    p_num : list of int
        For each converged position, the number of data points within
        ``bandwidth`` (Euclidean distance) of it.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    stop_thresh = 1e-3 * bandwidth

    # torch.tensor always copies, so the caller's arrays are never aliased.
    X = torch.tensor(np.asarray(data), dtype=torch.float64, device=device)
    S = torch.tensor(np.asarray(seed), dtype=torch.float64, device=device)
    B = torch.tensor(bandwidth, dtype=torch.float64, device=device)

    n_iter = 0
    while True:
        # weight[i, j]: Gaussian kernel value between seed i and point j.
        weight = gaussian(euc_batch(S, X), B)
        S_old = S
        # Move every seed to the kernel-weighted mean of all points.
        S = (weight[:, :, None] * X).sum(dim=1) / weight.sum(dim=1)[:, None]
        n_iter += 1
        # Stop when the mean seed displacement is negligible or the budget
        # is used up (">=" also terminates when max_iter <= 0).
        if torch.norm(S - S_old, dim=1).mean() < stop_thresh or n_iter >= max_iter:
            break

    # Gaussian weights are essentially never exactly 1, so counting
    # ``weight == 1`` (a leftover from the flat 0/1 kernel) reported zero
    # support for every center. Count the points inside the bandwidth
    # instead; callers rank duplicate centers by this number.
    p_num = (euc_batch(S, X) <= B).sum(dim=1).tolist()

    my_mean = S.cpu().numpy()

    return my_mean, p_num
68 |
69 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Faster-Mean-Shift-Euc
2 | A faster mean-shift algorithm using the Euclidean distance metric. The algorithm is GPU-accelerated and achieves a satisfactory speedup while keeping GPU memory consumption low. Below is a brief introduction to running it; for details of the algorithm, please refer to our [paper1](https://arxiv.org/abs/2112.13891) and [paper2](https://doi.org/10.1016/j.media.2021.102048).
3 |
4 | A version using the cosine distance metric is provided in a separate repository, [Faster-Mean-Shift](https://github.com/masqm/Faster-Mean-Shift).
5 |
6 |
7 | ## Environment
8 | Win10
9 |
10 | VS2019
11 |
12 | Anaconda 2020.11
13 |
14 | For the required packages, see [requirements.txt](https://github.com/masqm/Faster-Mean-Shift-Euc/blob/main/requirements.txt).
15 |
16 | Please make sure to install a CUDA-enabled PyTorch build that is compatible with your GPU and driver.
17 |
18 | ## Example
19 | Using our algorithm is similar to calling `MeanShift` in scikit-learn.
20 | An example of how to run the algorithm is given in [main.py](https://github.com/masqm/Faster-Mean-Shift-Euc/blob/main/FMS-Euc-git/main.py):
21 |
22 | # Generate a blob dataset.
23 | n_samples = 100000
24 | blobs = datasets.make_blobs(n_samples=n_samples, random_state=9)
25 |
26 | # Normalize dataset for easier parameter selection
27 | X, y = blobs
28 | X = StandardScaler().fit_transform(X)
29 |
30 | # Estimate bandwidth for mean shift(Select 1000 points)
31 | bandwidth = cluster.estimate_bandwidth(X[0:999])
32 | bandwidth_gpu = 2*bandwidth/(X.max()-X.min())
33 |
34 | # Obtain results
35 | ms = MeanShiftEuc(bandwidth=bandwidth_gpu, cluster_all=True, GPU=True)
36 | ms.fit(X)
37 | labels = ms.labels_
38 |
39 | Here is a helpful Q&A on image segmentation and choosing the bandwidth: https://github.com/masqm/Faster-Mean-Shift-Euc/issues/1
40 |
41 | If you encounter any problems or find a bug, you are very welcome to contact me at Mengyang.Zhao.TH@dartmouth.edu. If you use this code in your research, please cite our [paper1](https://arxiv.org/abs/2112.13891) and [paper2](https://doi.org/10.1016/j.media.2021.102048). Thanks!
42 |
--------------------------------------------------------------------------------
/FMS-Euc-git/meanshift/mean_shift_gpu.py:
--------------------------------------------------------------------------------
1 | """Mean shift clustering algorithm.
2 |
3 | Mean shift clustering aims to discover *blobs* in a smooth density of
4 | samples. It is a centroid based algorithm, which works by updating candidates
5 | for centroids to be the mean of the points within a given region. These
6 | candidates are then filtered in a post-processing stage to eliminate
7 | near-duplicates to form the final set of centroids.
8 |
9 | Seeding is performed by sampling random data points; the number of seeds is doubled until enough seeds per found cluster are used.
10 | """
11 |
12 | # Author Mengyang Zhao
13 |
14 | # Based on: Conrad Lee
15 | # Alexandre Gramfort
16 | # Gael Varoquaux
17 | # Martino Sorbaro
18 |
19 | import numpy as np
20 | import warnings
21 | import math
22 |
23 | from collections import defaultdict
24 | #from sklearn.externals import six
25 | from sklearn.utils.validation import check_is_fitted
26 | from sklearn.utils import check_random_state, gen_batches, check_array
27 | from sklearn.base import BaseEstimator, ClusterMixin
28 | from sklearn.neighbors import NearestNeighbors
29 | from sklearn.metrics.pairwise import pairwise_distances_argmin
30 | from joblib import Parallel
31 | from joblib import delayed
32 |
33 | from meanshift.batch_seed import meanshift_torch
34 | from random import shuffle
35 |
36 |
# SEED_NUM: initial number of random seeds; gpu_seed_adjust doubles it in place.
# L: minimum seeds per found cluster; mean_shift_euc accepts a result once
#    L * n_clusters <= SEED_NUM.
# H: upper seeds-per-cluster ratio; not used by any live code in this module.
SEED_NUM, L, H = 128, 8, 32
41 |
42 |
def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0, n_jobs=None):
    """Estimate the bandwidth to use with the mean-shift algorithm.

    The estimate is the average, over the (optionally subsampled) points, of
    the largest of each point's ``int(n * quantile)`` nearest-neighbour
    distances. This takes time at least quadratic in the number of points
    used, so for large datasets set ``n_samples`` to a small value.

    Parameters
    ----------
    X : array-like, shape=[n_samples, n_features]
        Input points.

    quantile : float, default 0.3
        Should be in [0, 1]. 0.5 means that the median of all pairwise
        distances is used.

    n_samples : int, optional
        Number of randomly chosen points to use. If not given, all points
        are used.

    random_state : int, RandomState instance or None, default 0
        Generator used to choose the subsample when ``n_samples`` is given.
        Use an int to make the randomness deterministic.

    n_jobs : int or None, optional (default=None)
        Number of parallel jobs for the neighbour search. ``None`` means 1
        unless in a ``joblib.parallel_backend`` context; ``-1`` means using
        all processors.

    Returns
    -------
    bandwidth : float
        The estimated bandwidth.
    """
    X = check_array(X)

    rng = check_random_state(random_state)
    if n_samples is not None:
        X = X[rng.permutation(X.shape[0])[:n_samples]]

    # NearestNeighbors cannot be fitted with n_neighbors = 0.
    k = max(1, int(X.shape[0] * quantile))
    nbrs = NearestNeighbors(n_neighbors=k, n_jobs=n_jobs).fit(X)

    # Query in chunks of 500 points to bound memory; keep, per point, the
    # largest of its k neighbour distances.
    total = 0.
    for chunk in gen_batches(len(X), 500):
        dist, _ = nbrs.kneighbors(X[chunk, :], return_distance=True)
        total += dist.max(axis=1).sum()

    return total / X.shape[0]
97 |
98 |
def gpu_seed_generator(codes):
    """Return up to ``SEED_NUM`` rows of ``codes`` at distinct random indices.

    The module-level ``SEED_NUM`` is read at call time, so an increase made by
    ``gpu_seed_adjust`` takes effect on the next call.
    """
    order = list(range(len(codes)))
    shuffle(order)
    return codes[order[:SEED_NUM]]
107 |
def gpu_seed_adjust(codes):
    """Double the module-wide seed budget, then draw a fresh random seed set.

    Side effect: ``SEED_NUM`` stays doubled for the whole module, so later
    calls (including later fits) also use the larger budget.
    """
    global SEED_NUM
    SEED_NUM = SEED_NUM * 2
    return gpu_seed_generator(codes)
113 |
def get_N(P, r, I):
    """Return the number of seeds to allocate, clamped to the range [32, 256].

    The raw estimate is ``ln(1 - P**(1/I)) / ln(1 - r/I)``. When ``r`` is
    below 0.1 (treated as "no foreground instances") the minimum of 32 is
    returned directly.

    Parameters
    ----------
    P : float
        Probability term in (0, 1); presumably the tolerated probability of
        missing an instance. TODO confirm against the paper.
    r : float
        Foreground ratio.
    I : float
        Presumably the expected number of instances. TODO confirm.

    Returns
    -------
    int
    """
    # No foreground instances: still allocate the minimum number of seeds.
    if r < 0.1:
        return 32

    log_p = math.log(P, math.e)
    numerator = math.log(1 - math.e ** (log_p / I), math.e)
    denominator = math.log(1 - r / I, math.e)
    estimate = numerator / denominator

    # At least 32 seeds; at most 256, the GPU-memory limit this was tuned for
    # (raise it if your GPU allows).
    return int(min(max(estimate, 32), 256))
131 |
132 |
def _remove_near_duplicates(center_intensity_dict, bandwidth):
    """Drop every center lying within ``bandwidth`` of a better-supported one."""
    # Visit centers from most to least supported (ties broken by the center
    # coordinates), so that among near duplicates the one with more points
    # is kept.
    sorted_by_intensity = sorted(center_intensity_dict.items(),
                                 key=lambda tup: (tup[1], tup[0]),
                                 reverse=True)
    sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
    unique = np.ones(len(sorted_centers), dtype=bool)
    # Euclidean radius search, consistent with the Euclidean distances used
    # by the mean-shift iterations and by the label assignment (the cosine
    # metric here was a leftover from the cosine-distance variant).
    nbrs = NearestNeighbors(radius=bandwidth).fit(sorted_centers)
    for i, center in enumerate(sorted_centers):
        if unique[i]:
            neighbor_idxs = nbrs.radius_neighbors([center],
                                                  return_distance=False)[0]
            unique[neighbor_idxs] = False
            unique[i] = True  # keep the current center itself
    return sorted_centers[unique]


def _assign_labels(X, cluster_centers, bandwidth, cluster_all):
    """Label each sample with the index of its nearest cluster center."""
    nbrs = NearestNeighbors(n_neighbors=1).fit(cluster_centers)
    distances, idxs = nbrs.kneighbors(X)
    if cluster_all:
        return idxs.flatten()
    # Orphans farther than ``bandwidth`` from every center get label -1.
    labels = np.full(X.shape[0], -1, dtype=int)
    within = distances.flatten() <= bandwidth
    labels[within] = idxs.flatten()[within]
    return labels


def mean_shift_euc(X, bandwidth=None, seeds=None,
                   cluster_all=True, GPU=True):
    """Perform Gaussian-kernel mean shift clustering with Euclidean distance.

    The shifting itself is done by ``meanshift_torch``. Converged centers
    closer than ``bandwidth`` to a better-supported center are merged. If
    fewer than ``L`` seeds per found cluster were used, the seed budget
    ``SEED_NUM`` is doubled and the clustering is repeated.

    Parameters
    ----------
    X : ndarray, shape=[n_samples, n_features]
        Input data.

    bandwidth : float, optional
        Kernel bandwidth. If not given, it is computed with this module's
        ``estimate_bandwidth``, which takes at least quadratic time in the
        number of samples.

    seeds : array-like, shape=[n_seeds, n_features] or None
        Initial kernel locations. If None, ``SEED_NUM`` rows of ``X`` are
        sampled at random.

    cluster_all : boolean, default True
        If true, then all points are clustered, even those orphans that are
        not within any kernel. Orphans are assigned to the nearest kernel.
        If false, then orphans are given cluster label -1.

    GPU : bool, default True
        Use the GPU (PyTorch) implementation. Only True is supported.

    Returns
    -------
    cluster_centers : array, shape=[n_clusters, n_features]
        Coordinates of cluster centers.

    labels : array, shape=[n_samples]
        Cluster labels for each point.

    Raises
    ------
    ValueError
        If ``bandwidth`` is not positive, or no point was near any seed.
    NotImplementedError
        If ``GPU`` is False.
    """
    if not GPU:
        # Previously GPU=False skipped all work and then failed with an
        # UnboundLocalError; fail early with a clear message instead.
        raise NotImplementedError(
            "Only the GPU implementation is available; use GPU=True.")
    if bandwidth is None:
        bandwidth = estimate_bandwidth(X)
    elif bandwidth <= 0:
        raise ValueError("bandwidth needs to be greater than zero or None, "
                         "got %f" % bandwidth)
    if seeds is None:
        seeds = gpu_seed_generator(X)

    # Centers found in earlier (smaller) seed rounds are kept as well; near
    # duplicates are merged by _remove_near_duplicates.
    center_intensity_dict = {}
    while True:
        centers, counts = meanshift_torch(X, seeds, bandwidth)  # GPU work
        for center, count in zip(centers, counts):
            center_intensity_dict[tuple(center)] = count

        if not center_intensity_dict:
            # Nothing near the seeds.
            raise ValueError("No point was within bandwidth=%f of any seed. "
                             "Try a different seeding strategy or increase "
                             "the bandwidth." % bandwidth)

        cluster_centers = _remove_near_duplicates(center_intensity_dict,
                                                  bandwidth)
        labels = _assign_labels(X, cluster_centers, bandwidth, cluster_all)

        # Accept once there are at least L seeds per found cluster;
        # otherwise double the seed budget and retry.
        if L * len(cluster_centers) <= SEED_NUM:
            break
        seeds = gpu_seed_adjust(X)

    return cluster_centers, labels
261 |
262 |
263 |
class MeanShiftEuc(BaseEstimator, ClusterMixin):
    """Mean shift clustering with a Gaussian (RBF) kernel and Euclidean distance.

    Mean shift clustering aims to discover "blobs" in a smooth density of
    samples. It is a centroid-based algorithm, which works by updating
    candidates for centroids to be the kernel-weighted mean of the points
    around them. These candidates are then filtered in a post-processing
    stage to eliminate near-duplicates to form the final set of centroids.

    Seeds are sampled at random from the data; their number is doubled until
    there are enough seeds per found cluster (see ``mean_shift_euc``).

    Parameters
    ----------
    bandwidth : float, optional
        Bandwidth used in the RBF kernel.

        If not given, the bandwidth is estimated with this module's
        ``estimate_bandwidth``, which takes at least quadratic time in the
        number of samples.

    seeds : array, shape=[n_seeds, n_features], optional
        Seeds used to initialize kernels. If not set, rows of ``X`` are
        sampled at random.

    cluster_all : boolean, default True
        If true, then all points are clustered, even those orphans that are
        not within any kernel. Orphans are assigned to the nearest kernel.
        If false, then orphans are given cluster label -1.

    GPU : bool, default True
        Use the GPU (PyTorch) implementation. Only True is supported.

    Attributes
    ----------
    cluster_centers_ : array, [n_clusters, n_features]
        Coordinates of cluster centers.

    labels_ : array, [n_samples]
        Labels of each point.

    Examples
    --------
    >>> import numpy as np
    >>> X = np.array([[1, 1], [2, 1], [1, 0],
    ...               [4, 7], [3, 5], [3, 6]])
    >>> clustering = MeanShiftEuc(bandwidth=2).fit(X)  # doctest: +SKIP
    >>> clustering.predict([[0, 0], [5, 5]])  # doctest: +SKIP

    References
    ----------

    Dorin Comaniciu and Peter Meer, "Mean Shift: A robust approach toward
    feature space analysis". IEEE Transactions on Pattern Analysis and
    Machine Intelligence. 2002. pp. 603-619.

    """
    def __init__(self, bandwidth=None, seeds=None, cluster_all=True, GPU=True):
        self.bandwidth = bandwidth
        self.seeds = seeds
        self.cluster_all = cluster_all
        self.GPU = GPU

    def fit(self, X, y=None):
        """Perform clustering.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Samples to cluster.

        y : Ignored

        Returns
        -------
        self : MeanShiftEuc
            The fitted estimator.
        """
        X = check_array(X)
        self.cluster_centers_, self.labels_ = \
            mean_shift_euc(X, bandwidth=self.bandwidth, seeds=self.seeds,
                           cluster_all=self.cluster_all, GPU=self.GPU)
        return self

    def predict(self, X):
        """Predict the closest cluster each sample in X belongs to.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape=[n_samples, n_features]
            New data to predict.

        Returns
        -------
        labels : array, shape [n_samples,]
            Index of the nearest cluster center (Euclidean distance).
        """
        check_is_fitted(self, "cluster_centers_")

        return pairwise_distances_argmin(X, self.cluster_centers_)
370 |
--------------------------------------------------------------------------------