#-------------------------------------------------------------------- # procedure: # solve a hard-margin prime svm # input: # x feature vectors of training data # y labels of training data # output: # ret [w; b] where w x + b = 0 is the obtained hyperplane # ------------------------------------------------------------------- function ret = svm(x, y) l = rows(x); # number of training data n = columns(x); # demension of feature vectors x0 = zeros(n+1, 1); H = eye(n+1); H(n+1, n+1) = 0; q = zeros(n+1, 1); A = []; b = []; lb = []; ub = []; Ai = [x, ones(l, 1)]; Al = -inf * ones(l, 1); Au = inf * ones(l, 1); for i = 1:l if y(i) == 1 Al(i) = 1; else Au(i) = -1; endif endfor [sol, obj, info, lambda] = qp (x0, H, q, A, b, lb, ub, Al, Ai, Au); ret = sol; endfunction;動作確認用のユーティリティ関数を書いてみました。
まず学習データをプロットする関数。
#-------------------------------------------------------------------- # procedure: # plot training data # input: # x feature vectors of training data # y label of training data # output: # n/a # ------------------------------------------------------------------- function plotData(x, y) l = rows(x); # number of training data hold on; for i = 1:l if y(i) == 1 plot (x(i, 1), x(i, 2), 'ro', 'markersize', 10); else plot (x(i, 1), x(i, 2), 'go', 'markersize', 10); endif endfor endfunction;そして、直線wx + b = 0をプロットする関数。
#-------------------------------------------------------------------- # procedure: # plot a line w x + b = 0 # input: # w weight vector # b bias # output: # n/a # ------------------------------------------------------------------- function plotLine(w, b) ezplot (@(x, y) w(1) * x + w(2) * y + b); title (""); endfunction;使ってみます。
octave> x = [ > 0 0 > 1 0 > 0 1 > ]; octave> y = [ > 1 > -1 > -1 > ]; octave> plotData(x, y); octave> ret = svm(x, y); octave> plotLine(ret(1:2), ret(3));とやると、以下のようなグラフが表示されます。
ふむ。ちゃんと出来てるっぽいですね。
他の学習データパターンでもやってみました。
0 件のコメント:
コメントを投稿