Ex 3: Plot the decision surface of a decision tree on the iris dataset
此範例利用決策樹分類器將資料集進行分類,找出各類別的分類邊界。以鳶尾花資料集當作範例,每次取兩個特徵做訓練,個別繪製不同品種的鳶尾花特徵的分布範圍。對於每對的鳶尾花特徵,決策樹學習推斷出簡單的分類規則,構成決策邊界。

範例目的:

    1.
    資料集:iris 鳶尾花資料集
    2.
    特徵:鳶尾花特徵
    3.
    預測目標:是哪一種鳶尾花
    4.
    機器學習方法:decision tree 決策樹

(一)引入函式庫及內建測試資料庫

    from sklearn.datasets import load_iris將鳶尾花資料庫存入,iris為一個dict型別資料。
    每筆資料中有4個特徵,一次取2個特徵,共有6種排列方式。
    X (特徵資料) 以及 y (目標資料)。
    DecisionTreeClassifier 建立決策樹分類器。
1
import numpy as np
2
import matplotlib.pyplot as plt
3
4
from sklearn.datasets import load_iris
5
from sklearn.tree import DecisionTreeClassifier
6
7
iris = load_iris()
8
9
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
10
[1, 2], [1, 3], [2, 3]]):
11
X = iris.data[:, pair]
12
y = iris.target
Copied!

(二)建立Decision Tree分類器

建立模型及分類器訓練

    DecisionTreeClassifier():決策樹分類器。
    fit(特徵資料, 目標資料):利用特徵資料及目標資料對分類器進行訓練。
1
clf = DecisionTreeClassifier().fit(X, y)
Copied!

(三)繪製決策邊界及訓練點

    np.meshgrid:利用特徵之最大最小值,建立預測用網格 xx, yy
    clf.predict:預估分類結果。
    plt.contourf:繪製決策邊界。
    plt.scatter(X,y):將X、y以點的方式繪製於平面上,c為數據點的顏色,label為圖例。
1
plt.subplot(2, 3, pairidx + 1)
2
3
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
4
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
5
6
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
7
np.arange(y_min, y_max, plot_step))
8
9
10
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #np.c_ 串接兩個list,np.ravel將矩陣變為一維
11
12
Z = Z.reshape(xx.shape)
13
14
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
15
16
plt.xlabel(iris.feature_names[pair[0]])
17
plt.ylabel(iris.feature_names[pair[1]])
18
plt.axis("tight")
19
20
for i, color in zip(range(n_classes), plot_colors):
21
idx = np.where(y == i)
22
plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
23
cmap=plt.cm.Paired)
24
25
plt.axis("tight")
Copied!

(四)完整程式碼

1
print(__doc__)
2
3
import numpy as np
4
import matplotlib.pyplot as plt
5
6
from sklearn.datasets import load_iris
7
from sklearn.tree import DecisionTreeClassifier
8
9
# Parameters
10
n_classes = 3
11
plot_colors = "bry"
12
plot_step = 0.02
13
14
# Load data
15
iris = load_iris()
16
17
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
18
[1, 2], [1, 3], [2, 3]]):
19
20
# We only take the two corresponding features
21
X = iris.data[:, pair]
22
y = iris.target
23
# Train
24
clf = DecisionTreeClassifier().fit(X, y)
25
26
# Plot the decision boundary
27
plt.subplot(2, 3, pairidx + 1)
28
29
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
30
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
31
32
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
33
np.arange(y_min, y_max, plot_step))
34
35
36
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #np.c_ 串接兩個list,np.ravel將矩陣變為一維
37
38
Z = Z.reshape(xx.shape)
39
40
41
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
42
43
plt.xlabel(iris.feature_names[pair[0]])
44
plt.ylabel(iris.feature_names[pair[1]])
45
plt.axis("tight")
46
47
48
# Plot the training points
49
for i, color in zip(range(n_classes), plot_colors):
50
idx = np.where(y == i)
51
plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
52
cmap=plt.cm.Paired)
53
54
plt.axis("tight")
55
56
plt.suptitle("Decision surface of a decision tree using paired features")
57
plt.legend()
58
plt.show()
Copied!