Ex 3: Plotting Validation Curves

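Model Selection / Example 3: Plotting Validation Curves

Purpose of this example:
    Analyze how the SVM kernel parameter gamma affects the model fit
    Use a plot to compare the fit obtained with different gamma values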

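1. Importing the functions and the model

    validation_curve shows the score obtained as a single parameter takes different values. The resulting curve makes it easy to see how well the model fits at each value of that parameter.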
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve
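To make the idea behind validation_curve more concrete, the sketch below builds the same kind of score arrays by hand: it loops over a few gamma values and cross-validates an SVC for each one. This is only a conceptual illustration, not necessarily how the library implements it. The 300-sample subset, the two gamma values, and cv=3 are assumptions chosen so that it runs quickly.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC

# Small subset so the sketch runs quickly (an assumption for illustration).
X, y = load_digits(return_X_y=True)
X, y = X[:300], y[:300]

gammas = [1e-4, 1e-3]  # illustrative values, not the range used in this example
train_rows, test_rows = [], []
for gamma in gammas:
    # Fit and score an SVC with this gamma on each of 3 cross-validation folds.
    result = cross_validate(SVC(gamma=gamma), X, y, cv=3,
                            scoring="accuracy", return_train_score=True)
    train_rows.append(result["train_score"])
    test_rows.append(result["test_score"])

# One row per gamma value, one column per fold.
train_scores = np.array(train_rows)
test_scores = np.array(test_rows)
print(train_scores.shape, test_scores.shape)
```

validation_curve wraps this kind of loop in a single call, which is why the example only needs to pass param_name and param_range.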

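2. Building the dataset and the model

    The dataset comes from sklearn.datasets.load_digits and contains 1,797 samples of handwritten digits from 0 to 9
    load_digits(return_X_y=True) returns X as the data and y as the target
    param_range holds the values of the parameter being varied, gamma: 5 values spaced logarithmically from 10^-6 to 10^-1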
X, y = load_digits(return_X_y=True)
param_range = np.logspace(-6, -1, 5)
# Each returned array has one row per gamma value and one column per CV fold
train_scores, test_scores = validation_curve(
    SVC(), X, y, param_name="gamma", param_range=param_range,
    scoring="accuracy", n_jobs=1)
# Aggregate over the folds (axis=1): one mean and one std per gamma value
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
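The following sketch may help when checking what param_range contains and what shape the arrays returned by validation_curve have. The 300-sample subset (X_small, y_small) and the explicit cv=3 are assumptions made only to keep the run short; the example itself uses the full dataset and the default cross-validation setting.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Five gamma values, evenly spaced on a log scale between 1e-6 and 1e-1.
param_range = np.logspace(-6, -1, 5)
print(param_range)

X, y = load_digits(return_X_y=True)
X_small, y_small = X[:300], y[:300]  # subset for speed (assumption)

train_scores, test_scores = validation_curve(
    SVC(), X_small, y_small, param_name="gamma", param_range=param_range,
    cv=3, scoring="accuracy")

# Rows correspond to gamma values, columns to cross-validation folds.
print(train_scores.shape)

# Averaging over axis=1 collapses the folds, leaving one value per gamma.
train_scores_mean = np.mean(train_scores, axis=1)
print(train_scores_mean.shape)
```

Because each row holds the per-fold scores for one gamma, taking the mean and standard deviation along axis=1 yields exactly one point and one band width per gamma for the plot.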

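3. Plotting: Validation Curve

    plt.semilogx plots the data with a logarithmic X axis
    plt.fill_between shades the band from one standard deviation below to one standard deviation above the mean train/test scores
    plt.legend displays the legend label of each data series; loc="best" sets the legend position, where best lets matplotlib choose the best location automatically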
plt.title("Validation Curve with SVM")
plt.xlabel(r"$\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
lw = 2
# Mean training score, with a shaded band of +/- one standard deviation
plt.semilogx(param_range, train_scores_mean, label="Training score",
             color="darkorange", lw=lw)
plt.fill_between(param_range, train_scores_mean - train_scores_std,
                 train_scores_mean + train_scores_std, alpha=0.2,
                 color="darkorange", lw=lw)
# Mean cross-validation score, with a shaded band of +/- one standard deviation
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
             color="navy", lw=lw)
plt.fill_between(param_range, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.2,
                 color="navy", lw=lw)
plt.legend(loc="best")
plt.show()
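[Figure: Validation Curve with SVM, showing the training score and the cross-validation score against gamma]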
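The validation curve in the figure above shows three cases:
    Very small gamma: both the training scores and the validation scores are low, which is called underfitting
    Very large gamma: the training scores are good but the validation scores are low, which is called overfitting
    Suitable gamma: both the training scores and the validation scores are high, which means the classifier performs very well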
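The three cases can also be read from the mean scores directly rather than from the plot. The sketch below picks the gamma with the highest mean cross-validation score and prints the gap between training and cross-validation scores. The score arrays are hypothetical values made up for illustration; they are not results from this example.

```python
import numpy as np

param_range = np.logspace(-6, -1, 5)
# Hypothetical mean scores for illustration only; not results from this example.
train_scores_mean = np.array([0.20, 0.85, 0.99, 1.00, 1.00])
test_scores_mean = np.array([0.20, 0.83, 0.97, 0.60, 0.15])

# The gamma with the highest mean cross-validation score is the natural choice.
best_index = int(np.argmax(test_scores_mean))
best_gamma = param_range[best_index]

# A large gap between training and cross-validation scores suggests overfitting;
# low scores on both suggest underfitting.
gap = train_scores_mean - test_scores_mean

for gamma, tr, te in zip(param_range, train_scores_mean, test_scores_mean):
    print(f"gamma={gamma:.1e}  train={tr:.2f}  cv={te:.2f}  gap={tr - te:.2f}")
print(f"best gamma by cross-validation score: {best_gamma:.1e}")
```

With the real arrays computed in section 2, the same two lines using np.argmax and the train/test difference could be applied unchanged.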

4. Source code listing

Python source code: plot_validation_curve.py
print(__doc__)

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve

X, y = load_digits(return_X_y=True)

param_range = np.logspace(-6, -1, 5)
print(param_range)
train_scores, test_scores = validation_curve(
    SVC(), X, y, param_name="gamma", param_range=param_range,
    scoring="accuracy", n_jobs=1)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.title("Validation Curve with SVM")
plt.xlabel(r"$\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
lw = 2
plt.semilogx(param_range, train_scores_mean, label="Training score",
             color="darkorange", lw=lw)
plt.fill_between(param_range, train_scores_mean - train_scores_std,
                 train_scores_mean + train_scores_std, alpha=0.2,
                 color="darkorange", lw=lw)
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
             color="navy", lw=lw)
plt.fill_between(param_range, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.2,
                 color="navy", lw=lw)
plt.legend(loc="best")
plt.show()