Search on the blog

2014年4月27日日曜日

OpenCV日記(9)カーネル法を使った非線形SVM

 OpenCVには機械学習の機能を提供するmlモジュールがあります。この中にSVMがあったので使ってみました。

 線形SVMに関するtutorial[1, 2]はいくつかありましたが、非線形のSVMを扱ったものが無かったのでサンプル実装してみました。パラメータがややこしいですが、公式リファレンス[3]に各パラメータの意味が載っているので必要に応じて参照してください。

出来たもの
赤丸、青丸がtraining dataです。赤領域、青領域がそれぞれSVMによって分離された空間です。灰色の丸で囲っているのがサポートベクトルです。




ソースコード
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;

const int WINDOW_HEIGHT = 512;
const int WINDOW_WIDTH = 512;
const int TRAINING_DATA_NUM = 200;

inline int labelOf(double x, double y) {
    // return (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5) <= 0.1 ? 1 : -1;
    // return y >= sin(10 * x) ? 1 : -1;
    return y >= 10*x*x ? 1 : -1;
}

int main() {
    // Data for visual representation
    Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3);

    // Set up training data
    Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1);
    Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1);

    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        float x = 1. * rand() / RAND_MAX;
        float y = 1. * rand() / RAND_MAX;
        
        trainingData.at<float>(i, 0) = x;
        trainingData.at<float>(i, 1) = y;

        labels.at<float>(i, 0) = labelOf(x, y);
    }

    // Set up SVM's parameters
    CvSVMParams params = CvSVMParams();

    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::RBF;
    params.degree = 0;
    params.gamma = 20;
    params.coef0 = 0;
    params.C = 10;
    params.nu = 0.0;
    params.p = 0.0;
    params.class_weights = NULL;
    params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS;
    params.term_crit.max_iter = 1000;
    params.term_crit.epsilon = 1e-6;

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingData, labels, Mat(), Mat(), params);

    // Show the decision regions given by the SVM
    Vec3b red(0, 0, 128), blue(128, 0, 0);
    for (int i = 0; i < image.rows; ++i) {
        for (int j = 0; j < image.cols; ++j)
        {
            Mat sampleMat = 
                (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH);
            float response = SVM.predict(sampleMat);
            if (response == 1)
                image.at<Vec3b>(i,j)  = red;
            else if (response == -1)
                image.at<Vec3b>(i,j)  = blue;
        }
    }

    // Show the training data
    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT;
        double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH;

        if (labels.at<float>(i, 0) == 1)
            circle(image, Point(x, y), 3, Scalar(0, 0, 255), -1, 8);
        else
            circle( image, Point(x, y), 3, Scalar(255, 0, 0), -1, 8);
    }

    // Show support vectors
    int c = SVM.get_support_vector_count();
    for (int i = 0; i < c; ++i) {
        const float* v = SVM.get_support_vector(i);
        circle(image,  Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 
               6, Scalar(128, 128, 128), 2, 8);
    }

    imshow("Kernel SVM", image);

    waitKey(0);
}
参考URL
[1] Introduction to Support Vector Machines
[2] Support Vector Machines for Non-Linearly Separable Data
[3] Support Vector Machines — OpenCV 2.4.9.0 documentation

0 件のコメント:

コメントを投稿