学生時代は、quadratic programmingのソルバーだけあって、自分でSVM書いて、カーネルの機能入れて、マルチクラスに対応させて、...とやっていたけど、今はライブラリ使えば中身を理解してなくても識別問題が解けてしまう。嬉しいような寂しいような...
出来たもの
象限ごとにトレーニングデータをラベリングした場合、同心円の半径ごとにトレーニングデータをラベリングした場合の学習結果。
ソースコード
前回書いたbinary SVMのサンプルを少し書き換えただけ。
#include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> #include <opencv2/ml/ml.hpp> using namespace cv; const int WINDOW_HEIGHT = 512; const int WINDOW_WIDTH = 512; const int TRAINING_DATA_NUM = 200; const Vec3b REGION_COLOR[] = { Vec3b(0, 0, 128), Vec3b(0, 128, 0), Vec3b(128, 0, 0), Vec3b(0, 64, 128) }; const Scalar DATA_COLOR[] = { Scalar(0, 0, 255), Scalar(0, 255, 0), Scalar(255, 0, 0), Scalar(0, 128, 255) }; inline int labelOf(double x, double y) { // concentric circle /* double d = (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5); if (d < 0.075) return 0; if (d < 0.15) return 1; if (d < 0.22) return 2; return 3; */ // quadrant x -= 0.5, y -= 0.5; return (x >= 0) * 2 + (y >= 0); } /** * マルチクラス非線形SVMのデモ. */ int main() { // Data for visual representation Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3); // Set up training data Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1); Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1); for (int i = 0; i < TRAINING_DATA_NUM; i++) { float x = 1. * rand() / RAND_MAX; float y = 1. * rand() / RAND_MAX; trainingData.at<float>(i, 0) = x; trainingData.at<float>(i, 1) = y; labels.at<float>(i, 0) = labelOf(x, y); } // Set up SVM's parameters CvSVMParams params = CvSVMParams(); params.svm_type = CvSVM::C_SVC; params.kernel_type = CvSVM::RBF; params.degree = 0; params.gamma = 20; params.coef0 = 0; params.C = 10; params.nu = 0.0; params.p = 0.0; params.class_weights = NULL; params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS; params.term_crit.max_iter = 1000; params.term_crit.epsilon = 1e-6; // Train the SVM CvSVM SVM; SVM.train(trainingData, labels, Mat(), Mat(), params); // Show the decision regions given by the SVM for (int i = 0; i < image.rows; ++i) { for (int j = 0; j < image.cols; ++j) { Mat sampleMat = (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH); int response = SVM.predict(sampleMat); image.at<Vec3b>(i,j) = REGION_COLOR[response]; } } // Show the training data for (int i = 0; i < TRAINING_DATA_NUM; i++) { double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT; double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH; int label = labels.at<float>(i, 0); circle(image, Point(x, y), 3, DATA_COLOR[label], -1, 8); } // Show support vectors int c = SVM.get_support_vector_count(); for (int i = 0; i < c; ++i) { const float* v = SVM.get_support_vector(i); circle(image, Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 6, Scalar(128, 128, 128), 2, 8); } imshow("Kernel SVM", image); waitKey(0); }
0 件のコメント:
コメントを投稿