線形SVMに関するtutorial[1, 2]はいくつかありましたが、非線形のSVMを扱ったものが無かったのでサンプル実装してみました。パラメータがややこしいですが、公式リファレンス[3]に各パラメータの意味が載っているので必要に応じて参照してください。
出来たもの
赤丸、青丸がtraining dataです。赤領域、青領域がそれぞれSVMによって分離された空間です。灰色の丸で囲っているのがサポートベクトルです。
ソースコード
#include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> #include <opencv2/ml/ml.hpp> using namespace cv; const int WINDOW_HEIGHT = 512; const int WINDOW_WIDTH = 512; const int TRAINING_DATA_NUM = 200; inline int labelOf(double x, double y) { // return (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5) <= 0.1 ? 1 : -1; // return y >= sin(10 * x) ? 1 : -1; return y >= 10*x*x ? 1 : -1; } int main() { // Data for visual representation Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3); // Set up training data Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1); Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1); for (int i = 0; i < TRAINING_DATA_NUM; i++) { float x = 1. * rand() / RAND_MAX; float y = 1. * rand() / RAND_MAX; trainingData.at<float>(i, 0) = x; trainingData.at<float>(i, 1) = y; labels.at<float>(i, 0) = labelOf(x, y); } // Set up SVM's parameters CvSVMParams params = CvSVMParams(); params.svm_type = CvSVM::C_SVC; params.kernel_type = CvSVM::RBF; params.degree = 0; params.gamma = 20; params.coef0 = 0; params.C = 10; params.nu = 0.0; params.p = 0.0; params.class_weights = NULL; params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS; params.term_crit.max_iter = 1000; params.term_crit.epsilon = 1e-6; // Train the SVM CvSVM SVM; SVM.train(trainingData, labels, Mat(), Mat(), params); // Show the decision regions given by the SVM Vec3b red(0, 0, 128), blue(128, 0, 0); for (int i = 0; i < image.rows; ++i) { for (int j = 0; j < image.cols; ++j) { Mat sampleMat = (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH); float response = SVM.predict(sampleMat); if (response == 1) image.at<Vec3b>(i,j) = red; else if (response == -1) image.at<Vec3b>(i,j) = blue; } } // Show the training data for (int i = 0; i < TRAINING_DATA_NUM; i++) { double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT; double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH; if (labels.at<float>(i, 0) == 1) circle(image, Point(x, y), 3, Scalar(0, 0, 255), -1, 8); else circle( image, Point(x, y), 3, Scalar(255, 0, 0), -1, 8); } // Show support vectors int c = SVM.get_support_vector_count(); for (int i = 0; i < c; ++i) { const float* v = SVM.get_support_vector(i); circle(image, Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 6, Scalar(128, 128, 128), 2, 8); } imshow("Kernel SVM", image); waitKey(0); }
参考URL
[1] Introduction to Support Vector Machines[2] Support Vector Machines for Non-Linearly Separable Data
[3] Support Vector Machines — OpenCV 2.4.9.0 documentation
0 件のコメント:
コメントを投稿