昨日に引き続き、scikit-learnの勉強。
今回は”手書き数字”を識別するデータセット「digits」にチャレンジした。
このデータセットは「iris」とは異なり、デフォルトパラメータで識別を行なっても良い結果はえられない。
ということでパラメータのチューニングを行わなければならない。
パラメータの候補を渡すとcross validationを行って良いパラメータを選んでくれる関数があったのでそれを使用した。
from sklearn import svm, datasets, grid_search, metrics
from numpy import logspace
# load data
digits = datasets.load_digits()
X = digits.data
Y = digits.target
# split data into (training, test)
train_data_num = 1000
trainX, trainY = X[:train_data_num], Y[:train_data_num]
testX, testY = X[train_data_num:], Y[train_data_num:]
# train SVM
# use cross validation to choose good parameters from grid points
parameters = {
'C' : logspace(-10, 10, base=2),
'gamma' : logspace(-10, 10, base=2)
}
grsrch = grid_search.GridSearchCV(svm.SVC(), parameters)
grsrch.fit(trainX[:100], trainY[:100])
# get an estimator with the best parameters
clf = grsrch.best_estimator_
clf.fit(trainX, trainY)
# predict test data
predictY = clf.predict(testX)
print metrics.classification_report(testY, predictY)
print metrics.confusion_matrix(testY, predictY)
パラメータCの候補は、2^{-10}, 2^{-9}, 2^{-8}, .... 2^{10}、
パラメータgammaの候補は、2^{-10}, 2^{-9}, 2^{-8}, .... 2^{10}
とした。
まず学習データのうち100個だけを使って、パラメータチューニングを実施した。
その後最良のパラメータを持つSVMをすべての学習データで学習させた。
結果は、以下のとおり。テストデータの識別率は97%。
precision recall f1-score support
0 1.00 0.99 0.99 79
1 0.99 0.96 0.97 80
2 0.99 0.99 0.99 77
3 0.97 0.86 0.91 79
4 0.99 0.95 0.97 83
5 0.94 0.99 0.96 82
6 0.99 0.99 0.99 80
7 0.95 0.99 0.97 80
8 0.94 1.00 0.97 76
9 0.94 0.98 0.96 81
avg / total 0.97 0.97 0.97 797
[[78 0 0 0 1 0 0 0 0 0]
[ 0 77 1 0 0 0 0 0 1 1]
[ 0 0 76 1 0 0 0 0 0 0]
[ 0 0 0 68 0 3 0 4 4 0]
[ 0 0 0 0 79 0 0 0 0 4]
[ 0 0 0 0 0 81 1 0 0 0]
[ 0 1 0 0 0 0 79 0 0 0]
[ 0 0 0 0 0 1 0 79 0 0]
[ 0 0 0 0 0 0 0 0 76 0]
[ 0 0 0 1 0 1 0 0 0 79]]
done.