EX 2: Mean-shift 群聚法.md
此範例展示一種強建的特徵空間分析法
    1.
    利用 make_blobs 來建立所需的樣本點
    2.
    利用均值漂移算法找到各類質心集合
    3.
    通過找到給定樣本的最近質心來給新樣本上標籤

(一)引入函式庫

引入函式如下:
    1.
    numpy : 產生陣列數值
    2.
    matplotlib.pyplot : 用來繪製影像
    3.
    sklearn.cluster import MeanShift, estimate_bandwidth : MeanShift:發現樣本的平滑密度中的點 ; estimate_bandwidth:計算要用於maen-shift演算法的頻寬
    4.
    sklearn.datasets.samples_generator import make_blobs : 產生用於clustering的等向高斯分布點
    5.
    itertools import cycle : 產生一個迭代器,對迭代器中的元素反覆執行
1
import numpy as np
2
from sklearn.cluster import MeanShift, estimate_bandwidth
3
from sklearn.datasets.samples_generator import make_blobs
Copied!
1
# Generate sample data
2
centers = [[1, 1], [-1, -1], [1, -1]]
3
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
Copied!
根據提供的3個中心點,產生各10000個等向高斯的點

(二)Clustering

1
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
2
3
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
4
ms.fit(X)
5
labels = ms.labels_
6
cluster_centers = ms.cluster_centers_
7
8
labels_unique = np.unique(labels)
9
n_clusters_ = len(labels_unique)
10
11
print("number of estimated clusters : %d" % n_clusters_)
Copied!
estimate_bandwidth 算出的 bandwidth 會用來作為提供 RBF krenel 的參數,用在 MeanShift 的 bandwidth 參數裡面 RBF kernel : 主要用於線性不可分的情形,將資料投射到更高維的空間,讓其變得可以線性分割 做聚集後就可得各類別的中心點,以及各點的label
1
# Plot result
2
import matplotlib.pyplot as plt
3
from itertools import cycle
4
5
plt.figure(1)
6
plt.clf()
7
8
colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
9
for k, col in zip(range(n_clusters_), colors):
10
my_members = labels == k
11
cluster_center = cluster_centers[k]
12
plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
13
plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
14
markeredgecolor='k', markersize=14)
15
plt.title('Estimated number of clusters: %d' % n_clusters_)
16
plt.show()
Copied!
colors : 在這用作圖形顏色切換 plt.plot(X[my_members, 0], X[my_members, 1], col + '.') : 畫出個別label的點 plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,markeredgecolor='k', markersize=14) : 畫出個別label的中心 最後秀出結果圖

(三)完整程式碼

Python source code:plot_mean_shift.py
1
"""
2
=============================================
3
A demo of the mean-shift clustering algorithm
4
=============================================
5
6
Reference:
7
8
Dorin Comaniciu and Peter Meer, "Mean Shift: A robust approach toward
9
feature space analysis". IEEE Transactions on Pattern Analysis and
10
Machine Intelligence. 2002. pp. 603-619.
11
12
"""
13
print(__doc__)
14
15
import numpy as np
16
from sklearn.cluster import MeanShift, estimate_bandwidth
17
from sklearn.datasets.samples_generator import make_blobs
18
19
# #############################################################################
20
# Generate sample data
21
centers = [[1, 1], [-1, -1], [1, -1]]
22
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
23
24
# #############################################################################
25
# Compute clustering with MeanShift
26
27
# The following bandwidth can be automatically detected using
28
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
29
30
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
31
ms.fit(X)
32
labels = ms.labels_
33
cluster_centers = ms.cluster_centers_
34
35
labels_unique = np.unique(labels)
36
n_clusters_ = len(labels_unique)
37
38
print("number of estimated clusters : %d" % n_clusters_)
39
40
# #############################################################################
41
# Plot result
42
import matplotlib.pyplot as plt
43
from itertools import cycle
44
45
plt.figure(1)
46
plt.clf()
47
48
colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
49
for k, col in zip(range(n_clusters_), colors):
50
my_members = labels == k
51
cluster_center = cluster_centers[k]
52
plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
53
plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
54
markeredgecolor='k', markersize=14)
55
plt.title('Estimated number of clusters: %d' % n_clusters_)
56
plt.show()
Copied!