Search on the blog

2015年5月2日土曜日

scikit-learn(3) Random Forest

 Random Forestの勉強をした。自分でも実装できる気がする。まぁ実装しないけど。

Random Forestの概要
  1. 複数個の決定木を用いたアンサンブル学習器
  2. 識別問題のときは最頻値を採用
  3. 回帰問題のときは平均値を採用
  4. ブートストラップサンプリングを行い、各決定木には異なる学習データを入力
  5. 各決定木には特徴量のサブセットを入力
4., 5.により、各決定木の相関性を減らすことができる。
特に、4.では学習データに含まれるノイズの影響を、5.では支配的な特徴量の影響によって木の分割が同様になってしまうことを防ぐ効果がある。
相関性の低い複数個の木の予測値を総合的に評価することで、決定木の弱点である過学習を解決することができる。

scikit-learnサンプル
前回と同じく「digits」の識別を行った。
from sklearn import datasets, metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split

# load data
digits = datasets.load_digits()
X = digits.data
y = digits.target

# split data into (train, test)
trainX, testX, trainY, testY = train_test_split(X, y, test_size=0.2)

# train a classifier
clf = RandomForestClassifier(min_samples_split=1, n_estimators=50, max_features=8)
clf.fit(trainX, trainY)

# predict test data
predictY = clf.predict(testX)
print metrics.classification_report(testY, predictY)
print metrics.confusion_matrix(testY, predictY)
テスト識別率は、98%。
パラメータは適当に選んでいるにも関わらず、よい結果がえられた。 乱数が入っているため結果は決定的ではないが、だいたい96〜98%程度に収まった。
             precision    recall  f1-score   support

          0       1.00      1.00      1.00        29
          1       1.00      1.00      1.00        40
          2       1.00      0.97      0.99        40
          3       0.98      0.93      0.95        45
          4       1.00      0.95      0.97        41
          5       0.97      1.00      0.99        33
          6       1.00      0.97      0.98        33
          7       0.92      1.00      0.96        36
          8       0.91      0.97      0.94        32
          9       1.00      1.00      1.00        31

avg / total       0.98      0.98      0.98       360

[[29  0  0  0  0  0  0  0  0  0]
 [ 0 40  0  0  0  0  0  0  0  0]
 [ 0  0 39  0  0  0  0  0  1  0]
 [ 0  0  0 42  0  1  0  1  1  0]
 [ 0  0  0  0 39  0  0  2  0  0]
 [ 0  0  0  0  0 33  0  0  0  0]
 [ 0  0  0  0  0  0 32  0  1  0]
 [ 0  0  0  0  0  0  0 36  0  0]
 [ 0  0  0  1  0  0  0  0 31  0]
 [ 0  0  0  0  0  0  0  0  0 31]]

0 件のコメント:

コメントを投稿