Ex 1: Recognizing hand-written digits
MathJax.Hub.Queue(["Typeset",MathJax.Hub]);

分類法/範例一: Recognizing hand-written digits

這個範例用來展示scikit-learn 機器學習套件,如何用SVM演算法來達成手寫的數字辨識
  1. 1.
    利用 make_classification 建立模擬資料
  2. 2.
    利用 sklearn.datasets.load_digits() 來讀取內建資料庫
  3. 3.
    用線性的SVC來做分類,以8x8的影像之像素值來當作特徵(共64個特徵)
  4. 4.
    metrics.classification_report 來提供辨識報表

(一)引入函式庫及內建手寫數字資料庫

引入之函式庫如下
  1. 1.
    matplotlib.pyplot: 用來繪製影像
  2. 2.
    sklearn.datasets: 用來繪入內建之手寫數字資料庫
  3. 3.
    sklearn.svm: SVM 支持向量機之演算法物件
  4. 4.
    sklearn.metrics: 用來評估辨識準確度以及報表的顯示
1
import matplotlib.pyplot as plt
2
from sklearn import datasets, svm, metrics
3
4
# The digits dataset
5
digits = datasets.load_digits()
Copied!
使用datasets.load_digits()將資料存入,digits為一個dict型別資料,我們可以用以下指令來看一下資料的內容。
1
for key,value in digits.items() :
2
try:
3
print (key,value.shape)
4
except:
5
print (key)
Copied!
顯示
說明
('images', (1797L, 8L, 8L))
共有 1797 張影像,影像大小為 8x8
('data', (1797L, 64L))
data 則是將8x8的矩陣攤平成64個元素之一維向量
('target_names', (10L,))
說明10種分類之對應 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
DESCR
資料之描述
('target', (1797L,))
記錄1797張影像各自代表那一個數字
接下來我們試著以下面指令來觀察資料檔,每張影像所對照的實際數字存在digits.target變數中
1
images_and_labels = list(zip(digits.images, digits.target))
2
for index, (image, label) in enumerate(images_and_labels[:4]):
3
plt.subplot(2, 4, index + 1)
4
plt.axis('off')
5
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
6
plt.title('Training: %i' % label)
Copied!

(二)訓練以及分類

接下來的步驟則是使用reshape指令將8x8的影像資料攤平成64x1的矩陣。 接著用classifier = svm.SVC(gamma=0.001)產生一個SVC分類器(Support Vector Classification)。再將一半的資料送入分類器來訓練classifier.fit(資料:898x64, 分類目標:898x1)。SVC之預設kernel function為RBF (radial basis function):
exp(γxx2)\exp(-\gamma |x-x'|^2)
. 其中SVC(gamma=0.001)就是在設定RBF函數裏的
γ\gamma
這個值必需要大於零。最後,再利用後半部份的資料來測試訓練完成之SVC分類機predict(data[n_samples / 2:])將預測結果存入predicted變數,而原先的真實目標資料則存於expected變數,用於下一節之準確度統計。
1
n_samples = len(digits.images)
2
3
# 資料攤平:1797 x 8 x 8 -> 1797 x 64
4
# 這裏的-1代表自動計算,相當於 (n_samples, 64)
5
data = digits.images.reshape((n_samples, -1))
6
7
# 產生SVC分類器
8
classifier = svm.SVC(gamma=0.001)
9
10
# 用前半部份的資料來訓練
11
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])
12
13
expected = digits.target[n_samples // 2:]
14
15
#利用後半部份的資料來測試分類器,共 899筆資料
16
predicted = classifier.predict(data[n_samples // 2:])
Copied!
若是觀察 expectedpredicted 矩陣中之前10個變數可以得到:
  • expected[:10] :[8 8 4 9 0 8 9 8 1 2]
  • predicted[:10]:[8 8 4 9 0 8 9 8 1 2]
這說明了前10個元素中,我們之前訓練完成的分類機,正確的分類了手寫數字資料。那對於全部測試資料的準確度呢?要如何量測?

(三)分類準確度統計

那在判斷準確度方面,我們可以使用一個名為「混淆矩陣」(Confusion matrix)的方式來統計。
1
print("Confusion matrix:\n%s"
2
% metrics.confusion_matrix(expected, predicted))
Copied!
使用sklearn中之metrics物件,metrics.confusion_matrix(真實資料:899, 預測資料:899)可以列出下面矩陣。此矩陣對角線左上方第一個數字 87,代表實際為0且預測為0的總數有87個,同一列(row)第五個元素則代表,實際為0但判斷為4的資料個數為1個。
1
Confusion matrix:
2
[[87 0 0 0 1 0 0 0 0 0]
3
[ 0 88 1 0 0 0 0 0 1 1]
4
[ 0 0 85 1 0 0 0 0 0 0]
5
[ 0 0 0 79 0 3 0 4 5 0]
6
[ 0 0 0 0 88 0 0 0 0 4]
7
[ 0 0 0 0 0 88 1 0 0 2]
8
[ 0 1 0 0 0 0 90 0 0 0]
9
[ 0 0 0 0 0 1 0 88 0 0]
10
[ 0 0 0 0 0 0 0 0 88 0]
11
[ 0 0 0 1 0 1 0 0 0 90]]
Copied!
我們可以利用以下的程式碼將混淆矩陣圖示出來。由圖示可以看出,實際為3時,有數次誤判為5,7,8。
1
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
2
import numpy as np
3
plt.imshow(cm, interpolation='nearest', cmap=cmap)
4
plt.title(title)
5
plt.colorbar()
6
tick_marks = np.arange(len(digits.target_names))
7
plt.xticks(tick_marks, digits.target_names, rotation=45)
8
plt.yticks(tick_marks, digits.target_names)
9
plt.tight_layout()
10
plt.ylabel('True label')
11
plt.xlabel('Predicted label')
12
13
plt.figure()
14
plot_confusion_matrix(metrics.confusion_matrix(expected, predicted))
Copied!
以手寫影像3為例,我們可以用四個數字來探討判斷的精準度。
  1. 1.
    True Positive(TP,真陽):實際為3且判斷為3,共79個
  2. 2.
    False Positive(FP,偽陽):判斷為3但判斷錯誤,共2個
  3. 3.
    False Negative(FN,偽陰):實際為3但判斷錯誤,共12個
  4. 4.
    True Negative(TN,真陰):實際不為3,且判斷正確。也就是其餘899-79-2-12=885個
而在機器學習理論中,我們通常用以下precision, recall, f1-score來探討精確度。以手寫影像3為例。
  • precision = TP/(TP+FP) = 79/81 = 0.98
  • 判斷為3且實際為3的比例為0.98
  • recall = TP/(TP+FN) = 79/91 = 0.87
  • 實際為3且判斷為3的比例為0.87
  • f1-score 則為以上兩者之「harmonic mean 調和平均數」
  • f1-score= 2 x precision x recall/(recision + recall) = 0.92
metrics物件裏也提供了方便的函式metrics.classification_report(expected, predicted)計算以上統計數據。
1
print("Classification report for classifier %s:\n%s\n"
2
% (classifier, metrics.classification_report(expected, predicted)))
Copied!
此報表最後的 support,則代表著實際為手寫數字的總數。例如實際為3的數字共有91個。
1
precision recall f1-score support
2
3
0 1.00 0.99 0.99 88
4
1 0.99 0.97 0.98 91
5
2 0.99 0.99 0.99 86
6
3 0.98 0.87 0.92 91
7
4 0.99 0.96 0.97 92
8
5 0.95 0.97 0.96 91
9
6 0.99 0.99 0.99 91
10
7 0.96 0.99 0.97 89
11
8 0.94 1.00 0.97 88
12
9 0.93 0.98 0.95 92
13
14
avg / total 0.97 0.97 0.97 899
Copied!
最後,用以下的程式碼可以觀察測試影像以及預測(分類)結果得對應關係。
1
images_and_predictions = list(
2
zip(digits.images[n_samples // 2:], predicted))
3
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
4
plt.subplot(2, 4, index + 5)
5
plt.axis('off')
6
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
7
plt.title('Prediction: %i' % prediction)
8
9
plt.show()
Copied!

(四)完整程式碼

Python source code: plot_digits_classification.py
1
print(__doc__)
2
3
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
4
# License: BSD 3 clause
5
6
# Standard scientific Python imports
7
import matplotlib.pyplot as plt
8
9
# Import datasets, classifiers and performance metrics
10
from sklearn import datasets, svm, metrics
11
12
# The digits dataset
13
digits = datasets.load_digits()
14
15
# The data that we are interested in is made of 8x8 images of digits, let's
16
# have a look at the first 4 images, stored in the `images` attribute of the
17
# dataset. If we were working from image files, we could load them using
18
# matplotlib.pyplot.imread. Note that each image must have the same size. For these
19
# images, we know which digit they represent: it is given in the 'target' of
20
# the dataset.
21
images_and_labels = list(zip(digits.images, digits.target))
22
for index, (image, label) in enumerate(images_and_labels[:4]):
23
plt.subplot(2, 4, index + 1)
24
plt.axis('off')
25
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
26
plt.title('Training: %i' % label)
27
28
# To apply a classifier on this data, we need to flatten the image, to
29
# turn the data in a (samples, feature) matrix:
30
n_samples = len(digits.images)
31
data = digits.images.reshape((n_samples, -1))
32
33
# Create a classifier: a support vector classifier
34
classifier = svm.SVC(gamma=0.001)
35
36
# We learn the digits on the first half of the digits
37
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])
38
39
# Now predict the value of the digit on the second half:
40
expected = digits.target[n_samples // 2:]
41
predicted = classifier.predict(data[n_samples // 2:])
42
43
print("Classification report for classifier %s:\n%s\n"
44
% (classifier, metrics.classification_report(expected, predicted)))
45
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))
46
47
images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
48
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
49
plt.subplot(2, 4, index + 5)
50
plt.axis('off')
51
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
52
plt.title('Prediction: %i' % prediction)
53
54
plt.show()
Copied!