Search on the blog

2013年3月9日土曜日

Octaveで線形SVMを解く

OctaveでハードマージンSVMを解くプログラムを書いてみました。主問題を解きます。
#--------------------------------------------------------------------
# procedure:
#   solve a hard-margin prime svm
# input: 
#   x feature vectors of training data
#   y labels of training data
# output:
#   ret [w; b] where w x + b = 0 is the obtained hyperplane
# -------------------------------------------------------------------
function ret = svm(x, y)
    l = rows(x);           # number of training data
    n = columns(x);        # demension of feature vectors

    x0 = zeros(n+1, 1);
    H = eye(n+1);
    H(n+1, n+1) = 0;
    q = zeros(n+1, 1);
    A = [];
    b = [];
    lb = [];
    ub = [];
    Ai = [x, ones(l, 1)];
    Al = -inf * ones(l, 1);
    Au = inf * ones(l, 1);    
    for i = 1:l
        if y(i) == 1
            Al(i) = 1;
        else
            Au(i) = -1; 
        endif  
    endfor

    [sol, obj, info, lambda] = qp (x0, H, q, A, b, lb, ub, Al, Ai, Au);
    ret = sol;
    
endfunction;
動作確認用のユーティリティ関数を書いてみました。
まず学習データをプロットする関数。
#--------------------------------------------------------------------
# procedure:
#   plot training data
# input: 
#   x feature vectors of training data
#   y label of training data
# output:
#   n/a
# -------------------------------------------------------------------
function plotData(x, y)
    l = rows(x);  # number of training data
    hold on;
    for i = 1:l
        if y(i) == 1
            plot (x(i, 1), x(i, 2), 'ro', 'markersize', 10);
        else
            plot (x(i, 1), x(i, 2), 'go', 'markersize', 10);
        endif
    endfor
endfunction;
そして、直線wx + b = 0をプロットする関数。
#--------------------------------------------------------------------
# procedure:
#   plot a line w x + b = 0
# input: 
#   w weight vector
#   b bias
# output:
#   n/a
# -------------------------------------------------------------------
function plotLine(w, b)
    ezplot (@(x, y) w(1) * x + w(2) * y + b);
    title ("");
endfunction;
使ってみます。
octave> x = [
> 0 0 
> 1 0
> 0 1
> ];
octave> y = [
> 1
> -1
> -1
> ];
octave> plotData(x, y);
octave> ret = svm(x, y);
octave> plotLine(ret(1:2), ret(3));
とやると、以下のようなグラフが表示されます。

ふむ。ちゃんと出来てるっぽいですね。
他の学習データパターンでもやってみました。



0 件のコメント:

コメントを投稿