Ex 5: ROC Curve with Visualization API
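Scikit-learn defines a simple API for creating machine learning visualizations. Its key feature is that plots can be drawn and visually adjusted quickly, without recomputing the underlying values. In this example, we show how to use the visualization API by comparing ROC curves.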

(1) Load the Data and Train an SVC

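First, we load the wine dataset with load_wine, a classic and simple multiclass dataset. We then convert it into a binary classification problem by treating class 2 as the positive class (y == 2) and the other two classes as negative.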
```python
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2
```
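The following sketch shows what the `y == 2` conversion does. It counts how many samples fall into each class before and after the comparison. The names `y_multi` and `y_binary` are illustrative and do not appear in the original example.

```python
import numpy as np
from sklearn.datasets import load_wine

# Load the original three-class labels (0, 1, 2)
X, y_multi = load_wine(return_X_y=True)

# Treat class 2 as the positive class; classes 0 and 1 become negatives
y_binary = y_multi == 2

print("feature matrix shape:", X.shape)
print("multiclass counts:", np.bincount(y_multi))
# Cast booleans to int so bincount counts False (index 0) and True (index 1)
print("binary counts (False, True):", np.bincount(y_binary.astype(int)))
```

As a result, the positive class is the minority: 48 of the 178 samples.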
Next, we split the data into training and test sets and train an SVC on the training set.
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
```
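An ROC curve is built from continuous scores, not hard class predictions. With its default `response_method="auto"`, `plot_roc_curve` tries `predict_proba` first and falls back to `decision_function`. This SVC is created without `probability=True`, so it only exposes `decision_function`. The sketch below checks this and inspects the scores. The variable `scores` is illustrative.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)
y = y == 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

svc = SVC(random_state=42)
svc.fit(X_train, y_train)

# probability=False is the default, so predict_proba is not available
print("has predict_proba:", hasattr(svc, "predict_proba"))

# Signed distances to the decision boundary: larger values mean "more positive"
scores = svc.decision_function(X_test)
print("train/test sizes:", X_train.shape[0], X_test.shape[0])
print("score shape:", scores.shape)
```

With the default `test_size` of 0.25, `train_test_split` holds out 45 of the 178 samples for testing.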

(2) Plot the ROC Curve

We plot the ROC curve with a single call to sklearn.metrics.plot_roc_curve. The returned svc_disp object lets us reuse the already computed ROC curve for the SVC in later plots.
```python
svc_disp = plot_roc_curve(svc, X_test, y_test)
plt.show()
```
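Conceptually, the plotting helper does three things. It scores the test set, computes false and true positive rates at each threshold plus the area under the curve, and stores these values in a `RocCurveDisplay`. The sketch below performs those steps by hand with `roc_curve` and `auc` and then builds the display object directly. It illustrates the idea and is not the helper's actual source. The Agg backend is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for non-interactive runs
from sklearn.datasets import load_wine
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)
y = y == 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42).fit(X_train, y_train)

# Step 1: continuous scores for the test set
scores = svc.decision_function(X_test)

# Step 2: one (FPR, TPR) point per retained threshold
fpr, tpr, thresholds = roc_curve(y_test, scores)

# Step 3: area under the curve via the trapezoidal rule
roc_auc = auc(fpr, tpr)

# Step 4: store the computed values in a display object, then draw them
display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc)
display.plot()
print(f"AUC = {roc_auc:.3f}")
```

Because the display only holds `fpr`, `tpr` and `roc_auc`, calling `display.plot()` again redraws the curve without touching the model.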

(3) Train a Random Forest and Plot Its ROC Curve

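We train a random forest classifier and plot its ROC curve alongside the SVC curve for comparison. plt.gca() returns the current Axes, and passing the same ax to both calls draws the two curves on one plot. Note that svc_disp draws the SVC curve with its plot method, without recomputing the ROC values. We also pass alpha=0.8 to both plotting calls to set the transparency of the curves.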
```python
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()
```

(4) Complete Code

```python
"""
================================
ROC Curve with Visualization API
================================
Scikit-learn defines a simple API for creating visualizations for machine
learning. The key feature of this API is to allow for quick plotting and
visual adjustments without recalculation. In this example, we will demonstrate
how to use the visualization API by comparing ROC curves.
"""
print(__doc__)

##############################################################################
# Load Data and Train an SVC
# --------------------------
# First, we load the wine dataset and convert it to a binary classification
# problem. Then, we train a support vector classifier on a training dataset.
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

##############################################################################
# Plotting the ROC Curve
# ----------------------
# Next, we plot the ROC curve with a single call to
# :func:`sklearn.metrics.plot_roc_curve`. The returned `svc_disp` object allows
# us to continue using the already computed ROC curve for the SVC in future
# plots.
svc_disp = plot_roc_curve(svc, X_test, y_test)
plt.show()

##############################################################################
# Training a Random Forest and Plotting the ROC Curve
# ---------------------------------------------------
# We train a random forest classifier and create a plot comparing it to the SVC
# ROC curve. Notice how `svc_disp` uses
# :func:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve
# without recomputing the values of the ROC curve itself. Furthermore, we
# pass `alpha=0.8` to the plot functions to adjust the alpha values of the
# curves.
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()
```