Page List

Search on the blog

2014年4月27日日曜日

OpenCV日記(10)マルチクラスSVM

 one against oneとかone against allとかを自分で実装するのかと思っていたけど、training dataのラベリングを{0, 1, 2, ... , k-1}とかすると自動でkクラス識別問題を解いてくれるようだ。
 学生時代は、quadratic programmingのソルバーだけあって、自分でSVM書いて、カーネルの機能入れて、マルチクラスに対応させて、...とやっていたけど、今はライブラリ使えば中身を理解してなくても識別問題が解けてしまう。嬉しいような寂しいような...

出来たもの
象限ごとにトレーニングデータをラベリングした場合、同心円の半径ごとにトレーニングデータをラベリングした場合の学習結果。


ソースコード
前回書いたbinary SVMのサンプルを少し書き換えただけ。
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;

const int WINDOW_HEIGHT = 512;
const int WINDOW_WIDTH = 512;
const int TRAINING_DATA_NUM = 200;
const Vec3b REGION_COLOR[] = {
    Vec3b(0, 0, 128),
    Vec3b(0, 128, 0),
    Vec3b(128, 0, 0),
    Vec3b(0, 64, 128)
};
const Scalar DATA_COLOR[] = {
    Scalar(0, 0, 255),
    Scalar(0, 255, 0),
    Scalar(255, 0, 0),
    Scalar(0, 128, 255)
};

inline int labelOf(double x, double y) {
    // concentric circle
    /*
    double d = (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5);
    if (d < 0.075) return 0;
    if (d < 0.15) return 1;
    if (d < 0.22) return 2;
    return 3;
    */

    // quadrant
    x -= 0.5, y -= 0.5;
    return (x >= 0) * 2 + (y >= 0);
}

/**
 * マルチクラス非線形SVMのデモ.
 */
int main() {
    // Data for visual representation
    Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3);

    // Set up training data
    Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1);
    Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1);

    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        float x = 1. * rand() / RAND_MAX;
        float y = 1. * rand() / RAND_MAX;
        
        trainingData.at<float>(i, 0) = x;
        trainingData.at<float>(i, 1) = y;

        labels.at<float>(i, 0) = labelOf(x, y);
    }

    // Set up SVM's parameters
    CvSVMParams params = CvSVMParams();

    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::RBF;
    params.degree = 0;
    params.gamma = 20;
    params.coef0 = 0;
    params.C = 10;
    params.nu = 0.0;
    params.p = 0.0;
    params.class_weights = NULL;
    params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS;
    params.term_crit.max_iter = 1000;
    params.term_crit.epsilon = 1e-6;

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingData, labels, Mat(), Mat(), params);

    // Show the decision regions given by the SVM
    for (int i = 0; i < image.rows; ++i) {
        for (int j = 0; j < image.cols; ++j)
        {
            Mat sampleMat = 
                (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH);
            int response = SVM.predict(sampleMat);
            image.at<Vec3b>(i,j)  = REGION_COLOR[response];
        }
    }

    // Show the training data
    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT;
        double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH;

        int label = labels.at<float>(i, 0);
        circle(image, Point(x, y), 3, DATA_COLOR[label], -1, 8);
    }

    // Show support vectors
    int c = SVM.get_support_vector_count();
    for (int i = 0; i < c; ++i) {
        const float* v = SVM.get_support_vector(i);
        circle(image,  Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 
               6, Scalar(128, 128, 128), 2, 8);
    }

    imshow("Kernel SVM", image);

    waitKey(0);
}

OpenCV日記(9)カーネル法を使った非線形SVM

 OpenCVには機械学習の機能を提供するmlモジュールがあります。この中にSVMがあったので使ってみました。

 線形SVMに関するtutorial[1, 2]はいくつかありましたが、非線形のSVMを扱ったものが無かったのでサンプル実装してみました。パラメータがややこしいですが、公式リファレンス[3]に各パラメータの意味が載っているので必要に応じて参照してください。

出来たもの
赤丸、青丸がtraining dataです。赤領域、青領域がそれぞれSVMによって分離された空間です。灰色の丸で囲っているのがサポートベクトルです。




ソースコード
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;

const int WINDOW_HEIGHT = 512;
const int WINDOW_WIDTH = 512;
const int TRAINING_DATA_NUM = 200;

inline int labelOf(double x, double y) {
    // return (x - 0.5) * (x - 0.5) + (y - 0.5) * (y - 0.5) <= 0.1 ? 1 : -1;
    // return y >= sin(10 * x) ? 1 : -1;
    return y >= 10*x*x ? 1 : -1;
}

int main() {
    // Data for visual representation
    Mat image = Mat::zeros(WINDOW_HEIGHT, WINDOW_WIDTH, CV_8UC3);

    // Set up training data
    Mat trainingData(TRAINING_DATA_NUM, 2, CV_32FC1);
    Mat labels(TRAINING_DATA_NUM, 1, CV_32FC1);

    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        float x = 1. * rand() / RAND_MAX;
        float y = 1. * rand() / RAND_MAX;
        
        trainingData.at<float>(i, 0) = x;
        trainingData.at<float>(i, 1) = y;

        labels.at<float>(i, 0) = labelOf(x, y);
    }

    // Set up SVM's parameters
    CvSVMParams params = CvSVMParams();

    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::RBF;
    params.degree = 0;
    params.gamma = 20;
    params.coef0 = 0;
    params.C = 10;
    params.nu = 0.0;
    params.p = 0.0;
    params.class_weights = NULL;
    params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS;
    params.term_crit.max_iter = 1000;
    params.term_crit.epsilon = 1e-6;

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingData, labels, Mat(), Mat(), params);

    // Show the decision regions given by the SVM
    Vec3b red(0, 0, 128), blue(128, 0, 0);
    for (int i = 0; i < image.rows; ++i) {
        for (int j = 0; j < image.cols; ++j)
        {
            Mat sampleMat = 
                (Mat_<float>(1,2) << 1.*j/WINDOW_HEIGHT, 1.*i/WINDOW_WIDTH);
            float response = SVM.predict(sampleMat);
            if (response == 1)
                image.at<Vec3b>(i,j)  = red;
            else if (response == -1)
                image.at<Vec3b>(i,j)  = blue;
        }
    }

    // Show the training data
    for (int i = 0; i < TRAINING_DATA_NUM; i++) {
        double x = trainingData.at<float>(i, 0) * WINDOW_HEIGHT;
        double y = trainingData.at<float>(i, 1) * WINDOW_WIDTH;

        if (labels.at<float>(i, 0) == 1)
            circle(image, Point(x, y), 3, Scalar(0, 0, 255), -1, 8);
        else
            circle( image, Point(x, y), 3, Scalar(255, 0, 0), -1, 8);
    }

    // Show support vectors
    int c = SVM.get_support_vector_count();
    for (int i = 0; i < c; ++i) {
        const float* v = SVM.get_support_vector(i);
        circle(image,  Point(v[0] * WINDOW_WIDTH, v[1] * WINDOW_HEIGHT), 
               6, Scalar(128, 128, 128), 2, 8);
    }

    imshow("Kernel SVM", image);

    waitKey(0);
}
参考URL
[1] Introduction to Support Vector Machines
[2] Support Vector Machines for Non-Linearly Separable Data
[3] Support Vector Machines — OpenCV 2.4.9.0 documentation

JavaでRSAを実装してみる

 JavaでRSAを実装してみました。といっても基本的なところを実装しただけなので実用性はありません。
 Javaの場合は、JCE(Java Cryptography Extension)という標準APIがあるので、そちらを使う方が実用的です。おまけとして、JCEを使ったソースコードも載せておきます。

アルゴリズム
  1. 大きい素数p, qを選びます。
  2. n = pqとします。このときΦ(n) = (p-1)(q-1)になります。
  3. k = Φ(n)とすると、オイラーの定理よりak = 1 (mod n) (式1)が成り立ちます。ただしaはnと互いに素な整数です。
  4. aed = a(mod n) となるようなe, dを選びます。(式1)より、ed = 1 (mod k)を満たすようなeとdを選べばよいことが分かります。
  5. 暗号化するときは、平文Tに対してC = Te(mod n)とします。(e, n)が暗号化キーとなります。
  6. 復号化するときは、暗号文Cに対してT = Cd(mod n)とします。(d, n)が複合化キーとなります。
シンプルなアルゴリズムですが、思いついた人はすごいですね。e乗したものをd乗すれば元に戻るし、逆にd乗したものをe乗すれば元に戻ります。これは、公開鍵で暗号化したものは秘密鍵で復号でき、秘密鍵で暗号化したものは公開鍵で復号できるということに対応しています。

ソースコード
BigInteger便利ですね。 冪乗高速化とか逆元の計算とか自分でやるつもりでしたが要りませんでした。
package com.kenjih.sample;

import java.math.BigInteger;
import java.util.Random;
import java.util.Scanner;

public class CustomRSASample {
    
    public static void main(String[] args) {
        System.out.println("Input text(ASCII only):");
        Scanner sc = new Scanner(System.in);
        String text = sc.nextLine();

        new CustomRSASample().run(text);
    }
    
    /**
     * RSA暗号化/復号化を行う(ASCII文字のみ対応).
     * 
     */
    public void run(String text) {
        
        // generate RSA keys 
        String[] keys = generateKeys(256);
        System.out.println("common key: " + keys[0]);
        System.out.println("public key: " + keys[1]);
        System.out.println("private key: " + keys[2]);
        
        // encrypt the text
        String encryptedText = encrypt(text, keys[1], keys[0]);
        System.out.println("cipher: " + encryptedText);
        
        // decrypt the text
        String decryptedText = decrypt(encryptedText, keys[2], keys[0]);
        System.out.println("plain: " + decryptedText);
    
    }
    
    /**
     * RSAのキーを生成する.
     * 
     * @return 16進数表記のキーを配列keys[]で返す.
     */
    public String[] generateKeys(int bitLength) {
        
        String[] keys = new String[3];
        Random rnd = new Random();
        
        for (;;) {
            BigInteger p = BigInteger.probablePrime(bitLength >> 1, rnd);
            BigInteger q = BigInteger.probablePrime(bitLength >> 1, rnd);
            
            if (p.equals(q))
                continue;
            
            // totient function of two different prime numbers
            BigInteger phi = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));
            
            BigInteger e = BigInteger.probablePrime(bitLength, rnd);
            if (e.gcd(phi).equals(BigInteger.ONE)) {
                keys[0] = p.multiply(q).toString(16);      // private/public key (common key)
                keys[1] = e.toString(16);                  // public key
                keys[2] = e.modInverse(phi).toString(16);  // private key
                break;
            }
        }
        
        return keys;
    }
    
    /**
     * 平文を暗号化する.
     * 
     */
    public String encrypt(String text, String publicKey, String commonKey) {
        BigInteger a = encode(text);
        BigInteger e = new BigInteger(publicKey, 16);
        BigInteger n = new BigInteger(commonKey, 16);
        
        return a.modPow(e, n).toString(16);
    }
    
    /**
     * 暗号文を復号化する.
     * 
     */
    public String decrypt(String hexCode, String privateKey, String commonKey) {
        BigInteger a = new BigInteger(hexCode, 16);
        BigInteger d = new BigInteger(privateKey, 16);
        BigInteger n = new BigInteger(commonKey, 16);
        
        return decode(a.modPow(d, n));
    }
    
    /**
     * ASCII文字を256進数のBigIntegerに変換する.
     * 
     */
    private BigInteger encode(String text) {
        BigInteger code = BigInteger.ZERO;
        
        for (int i = 0; i < text.length(); i++) {
            code = code.multiply(BigInteger.valueOf(256));
            code = code.add(BigInteger.valueOf(text.charAt(i)));
        }
            
        return code;
    }
    
    /**
     * 256進数のBigIntegerをASCII文字に変換する.
     * 
     */
    private String decode(BigInteger code) {
        StringBuilder sb = new StringBuilder();
        
        while (code.compareTo(BigInteger.ZERO) > 0) {
            int rem = code.mod(BigInteger.valueOf(256)).shortValue();
            sb.append((char)rem);
            code = code.divide(BigInteger.valueOf(256));
        }
        
        return sb.reverse().toString();
    }

}
実行結果は以下のようになります。
Input text(ASCII only):
hello, world.
common key: a8c7c7712a5ad3ef8244efe9f97e9edaf2586f74b4178075c88a82cf485197b7
public key: a1574aa5e7357321ffbdfffaa1f6135f3b5bbfde9389f9e21766860aaa643d49
private key: 84278732afc3e0d0de5c1acb800576c8ced6c983863c47aabd5671e656feebf9
cipher: 1159ca402742d7a8da249fdfc5265fd80d0146ccb1a3053e49c0961ae84f38f2
plain: hello, world.

JCEを使ったRSAの実装
上のソースコードはあくまでアルゴリズムを理解するためのもの(or プログラムの練習 or 学生の実験用)です。実用上はJCEを使うとよいです。
まず、RSAの実装を提供しているライブラリをインストールします。mavenを使っている場合は、以下をpom.xmlに追記します。
<dependency>
    <groupId>org.bouncycastle</groupId>
    <artifactId>bcprov-jdk16</artifactId>
    <version>1.45</version>
</dependency>
JCEのインターフェースを使って、上のライブラリの機能を使用します。
package com.kenjih.sample;

import java.security.Key;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.SecureRandom;
import java.security.Security;
import java.util.Scanner;

import javax.crypto.Cipher;

public class RSASample {
    public static void main(String[] args) {
        System.out.println("Input text:");
        Scanner sc = new Scanner(System.in);
        String text = sc.nextLine();

     new RSASample().run(text);
    }

    public void run(String text) {       
     Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider());
        
     byte[] input = text.getBytes();
        
        try {
            Cipher cipher = Cipher.getInstance("RSA/None/NoPadding", "BC");
            SecureRandom random = new SecureRandom();
            
            // Generate RSA keys
            KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA", "BC");
            generator.initialize(1024, random);
            KeyPair pair = generator.generateKeyPair();
            Key publicKey = pair.getPublic();
            Key privateKey = pair.getPrivate();
            System.out.println("public key: " + publicKey);
            System.out.println("private key: " + privateKey);
            
            // Encrypt a plain text
            cipher.init(Cipher.ENCRYPT_MODE, publicKey, random);
            byte[] cipherText = cipher.doFinal(input);
            System.out.println("cipher: " + new String(cipherText));
            
            // Decrypt a cipher text
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            byte[] plainText = cipher.doFinal(cipherText);
            System.out.println("plain : " + new String(plainText));

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

2014年4月22日火曜日

OpenCV日記(8)アフィン変換

 アフィン変換のtutorialを読みました。

アフィン変換とは、線形写像 + 平行移動で表すことの出来る操作です。

アフィン変換を使うと、
  • 回転
  • 平行移動
  • スケーリング
  • 反射
などを行うことが出来ます。

 これは面白そうです。ということでプログラムを作ってみました。
  1. 画像を表示。
  2. 画像上の点を選び、それらが変換後どこに写像するかを選ぶ。
  3. 写像元 - 写像先のマッピングが3点えられた時点で、対応するアフィン変換行列を計算。
  4. アフィン変換実行。
ということが出来るようにしました。

以下デモです。



以下ソースコードです。
#include <iostream>

#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"

using namespace cv;

const char* WINDOW_NAME = "Affine Transform Demo";
Mat img, tmp;
Point2f srcTri[3], dstTri[3];
int ptr = 0;
Scalar circleColors[] = {
    Scalar(255, 0, 0),
    Scalar(0, 255, 0),
    Scalar(0, 0, 255)
};

void applyAfineTransform() {

    std::cout << "------------------------------------------" << std::endl;
    std::cout << "Affine Transform begin." << std::endl;
    for (int i = 0; i < 3; i++) {
        std::cout << srcTri[i] << " is mapped to " << dstTri[i] << std::endl;
    }

    /// Get Affine matrix from mapping info
    Mat affine_mat = getAffineTransform(srcTri, dstTri);

    /// Apply Affine transform
    warpAffine(img, img, affine_mat, img.size());
        
    tmp = img.clone();
    imshow(WINDOW_NAME, img);

    std::cout << "Affine Transform end." << std::endl << std::endl;

}

void mouseClick(int event, int x, int y, int flags, void* userdata) {

    if (event != EVENT_LBUTTONDOWN) 
        return;

    Point2f point(x, y);
    
    if (ptr % 2 == 0)
        srcTri[ptr/2] = point;
    else
        dstTri[ptr/2] = point;

    circle(tmp, point, 6, circleColors[ptr/2], -1, 8, 0);
    imshow(WINDOW_NAME, tmp);

    if (++ptr >= 6) {
        ptr = 0;
        applyAfineTransform();
    }

}

int main( int, char** argv ) {

    img = imread(argv[1], 1);
    tmp = img.clone();

    namedWindow(WINDOW_NAME, WINDOW_AUTOSIZE);
    setMouseCallback(WINDOW_NAME, mouseClick, NULL);
    imshow(WINDOW_NAME, img);

    waitKey(0);

    return 0;
}