Page List

Search on the blog

2014年4月27日日曜日

OpenCV日記(10)マルチクラスSVM

 one against oneとかone against allとかを自分で実装するのかと思っていたけど、training dataのラベリングを{0, 1, 2, ... , k-1}とかすると自動でkクラス識別問題を解いてくれるようだ。
 学生時代は、quadratic programmingのソルバーだけあって、自分でSVM書いて、カーネルの機能入れて、マルチクラスに対応させて、...とやっていたけど、今はライブラリ使えば中身を理解してなくても識別問題が解けてしまう。嬉しいような寂しいような...

出来たもの
象限ごとにトレーニングデータをラベリングした場合、同心円の半径ごとにトレーニングデータをラベリングした場合の学習結果。


ソースコード
前回書いたbinary SVMのサンプルを少し書き換えただけ。
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;

const int WINDOW_HEIGHT = 512;
const int WINDOW_WIDTH = 512;
const int TRAINING_DATA_NUM = 200;
const Vec3b REGION_COLOR[] = {
    Vec3b(0, 0, 128),
    Vec3b(0, 128, 0),
    Vec3b(128, 0, 0),
    Vec3b(0, 64, 128)
};
const Scalar DATA_COLOR[] = {
    Scalar(0, 0, 255),
    Scalar(0, 255, 0),
    Scalar(255, 0, 0),
    Scalar(0, 128, 255)
};

inline int labelOf(double x, double y) {
    // concentric circle
    /*
    double d = (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5);
    if (d < 0.075) return 0;
    if (d < 0.15) return 1;
    if (d < 0.22) return 2;
    return 3;
    */

    // quadrant
    x -= 0.5, y -= 0.5;
    return (x >= 0) * 2 + (y >= 0);
}

/**
 * マルチクラス非線形SVMのデモ.
 */
int main() {
    // Data for visual representation
    Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3);

    // Set up training data
    Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1);
    Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1);

    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        float x = 1. * rand() / RAND_MAX;
        float y = 1. * rand() / RAND_MAX;
        
        trainingData.at<float>(i, 0) = x;
        trainingData.at<float>(i, 1) = y;

        labels.at<float>(i, 0) = labelOf(x, y);
    }

    // Set up SVM's parameters
    CvSVMParams params = CvSVMParams();

    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::RBF;
    params.degree = 0;
    params.gamma = 20;
    params.coef0 = 0;
    params.C = 10;
    params.nu = 0.0;
    params.p = 0.0;
    params.class_weights = NULL;
    params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS;
    params.term_crit.max_iter = 1000;
    params.term_crit.epsilon = 1e-6;

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingData, labels, Mat(), Mat(), params);

    // Show the decision regions given by the SVM
    for (int i = 0; i < image.rows; ++i) {
        for (int j = 0; j < image.cols; ++j)
        {
            Mat sampleMat = 
                (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH);
            int response = SVM.predict(sampleMat);
            image.at<Vec3b>(i,j)  = REGION_COLOR[response];
        }
    }

    // Show the training data
    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT;
        double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH;

        int label = labels.at<float>(i, 0);
        circle(image, Point(x, y), 3, DATA_COLOR[label], -1, 8);
    }

    // Show support vectors
    int c = SVM.get_support_vector_count();
    for (int i = 0; i < c; ++i) {
        const float* v = SVM.get_support_vector(i);
        circle(image,  Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 
               6, Scalar(128, 128, 128), 2, 8);
    }

    imshow("Kernel SVM", image);

    waitKey(0);
}

0 件のコメント:

コメントを投稿