Ex 2: Restricted Boltzmann Machine features for digit classification
This example uses `BernoulliRBM` feature extraction to improve the accuracy of handwritten digit classification. The Bernoulli Restricted Boltzmann Machine model (`BernoulliRBM`) performs effective nonlinear feature extraction on the data. To make the trained model more robust, each input image is shifted by one pixel up, left, right, and down, generating additional training data. The network hyper-parameters were chosen with a grid search, but since that search is very time-consuming it is not reproduced here. The example compares two approaches:

1. logistic regression on the raw pixel values
2. logistic regression on features extracted by `BernoulliRBM`

The results show that `BernoulliRBM` feature extraction improves classification accuracy.

(1) Importing the libraries

from __future__ import print_function

print(__doc__)

# Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
# License: BSD

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

(2) Data preprocessing, loading the data, and selecting the models

def nudge_dataset(X, Y):
    """
    Shift each digit image in the dataset by one pixel up, left, right,
    and down. This produces additional training data so that the trained
    model is more robust.
    """
    direction_vectors = [
        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]]

    shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',
                                  weights=w).ravel()
    X = np.concatenate([X] +
                       [np.apply_along_axis(shift, 1, X, vector)
                        for vector in direction_vectors])
    Y = np.concatenate([Y for _ in range(5)], axis=0)
    return X, Y

# Load Data
digits = datasets.load_digits()
X = np.asarray(digits.data, 'float32')
X, Y = nudge_dataset(X, digits.target)
X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)  # scale the grey-scale values to [0, 1]

# Split the data into training and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                    test_size=0.2,
                                                    random_state=0)

# Models we will use
logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)

classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
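As a quick sanity check (not part of the original example), one can run nudge_dataset on a few samples and confirm that it returns five times as many rows: the originals plus one block per shift direction.

X_small = np.asarray(digits.data[:3], 'float32')
X_aug, Y_aug = nudge_dataset(X_small, digits.target[:3])
print(X_small.shape, '->', X_aug.shape)  # (3, 64) -> (15, 64): originals + 4 shifted copies
print(Y_aug)  # the 3 labels repeated 5 times, so every shifted copy keeps its digit label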

(3) Setting the hyper-parameters and training the models

# Hyper-parameter values should be compared with cross-validation.
# These values were found with GridSearchCV; the search is skipped here to save time.
# GridSearchCV exhaustively tries every combination of the candidate
# parameter values and keeps the best-scoring one.
rbm.learning_rate = 0.06
rbm.n_iter = 20
# n_components = 100 means 100 hidden units, i.e. 100 extracted features.
# More components tend to give better accuracy, but take longer to fit.
rbm.n_components = 100
logistic.C = 6000.0

# Training RBM-Logistic Pipeline
classifier.fit(X_train, Y_train)

# Training Logistic regression
logistic_classifier = linear_model.LogisticRegression(C=100.0)
logistic_classifier.fit(X_train, Y_train)
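The comments above note that these values came from a GridSearchCV run that is skipped to save time. For reference, a minimal sketch of such a search over the pipeline might look like the following; the candidate grids here are illustrative guesses, not the ones the authors actually searched:

from sklearn.model_selection import GridSearchCV

# Illustrative candidate values only; the original search space is not given.
param_grid = {
    'rbm__learning_rate': [0.01, 0.06, 0.1],
    'rbm__n_components': [50, 100, 200],
    'logistic__C': [1000.0, 6000.0, 10000.0],
}
search = GridSearchCV(classifier, param_grid, cv=3)  # slow: refits the RBM for every combination
search.fit(X_train, Y_train)
print(search.best_params_)

The 'step__parameter' names are how scikit-learn addresses the parameters of individual Pipeline steps.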

(4) Evaluating classification accuracy

print()
print("Logistic regression using RBM features:\n%s\n" % (
    metrics.classification_report(
        Y_test,
        classifier.predict(X_test))))

print("Logistic regression using raw pixel features:\n%s\n" % (
    metrics.classification_report(
        Y_test,
        logistic_classifier.predict(X_test))))
Figure 1: with RBM feature extraction, logistic regression reaches an accuracy of 0.95
Figure 2: without any feature extraction, logistic regression on the raw pixels reaches an accuracy of 0.77
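classification_report prints per-class precision, recall, and F1. If only the single accuracy numbers quoted in the captions are wanted, metrics.accuracy_score (a small addition, not in the original script) returns them directly:

print("RBM + logistic accuracy: %.2f"
      % metrics.accuracy_score(Y_test, classifier.predict(X_test)))
print("Raw pixel accuracy:      %.2f"
      % metrics.accuracy_score(Y_test, logistic_classifier.predict(X_test)))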

(5) Plotting the 100 features extracted by the RBM

plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
    plt.subplot(10, 10, i + 1)
    plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
plt.suptitle('100 components extracted by RBM', fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()
Figure 3: the features extracted by the RBM
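The reshape to (8, 8) works because rbm.components_ stores one 64-element weight vector per hidden unit, a one-line check (not in the original script):

print(rbm.components_.shape)  # (100, 64): one weight vector per hidden unit, one weight per input pixel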

(6) Complete code

from __future__ import print_function

print(__doc__)

# Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
# License: BSD

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline


###############################################################################
# Setting up

def nudge_dataset(X, Y):
    """
    This produces a dataset 5 times bigger than the original one,
    by moving the 8x8 images in X around by 1px to left, right, down, up
    """
    direction_vectors = [
        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]]

    shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',
                                  weights=w).ravel()
    X = np.concatenate([X] +
                       [np.apply_along_axis(shift, 1, X, vector)
                        for vector in direction_vectors])
    Y = np.concatenate([Y for _ in range(5)], axis=0)
    return X, Y

# Load Data
digits = datasets.load_digits()
X = np.asarray(digits.data, 'float32')
X, Y = nudge_dataset(X, digits.target)
X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)  # 0-1 scaling

X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                    test_size=0.2,
                                                    random_state=0)

# Models we will use
logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)

classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

###############################################################################
# Training

# Hyper-parameters. These were set by cross-validation,
# using a GridSearchCV. Here we are not performing cross-validation to
# save time.
rbm.learning_rate = 0.06
rbm.n_iter = 20
# More components tend to give better prediction performance, but larger
# fitting time
rbm.n_components = 100
logistic.C = 6000.0

# Training RBM-Logistic Pipeline
classifier.fit(X_train, Y_train)

# Training Logistic regression
logistic_classifier = linear_model.LogisticRegression(C=100.0)
logistic_classifier.fit(X_train, Y_train)

###############################################################################
# Evaluation

print()
print("Logistic regression using RBM features:\n%s\n" % (
    metrics.classification_report(
        Y_test,
        classifier.predict(X_test))))

print("Logistic regression using raw pixel features:\n%s\n" % (
    metrics.classification_report(
        Y_test,
        logistic_classifier.predict(X_test))))

###############################################################################
# Plotting

plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
    plt.subplot(10, 10, i + 1)
    plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
plt.suptitle('100 components extracted by RBM', fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()