## 2013年3月9日土曜日

### Octaveで線形SVMを解く

OctaveでハードマージンSVMを解くプログラムを書いてみました。主問題を解きます。
```#--------------------------------------------------------------------
# procedure:
#   solve a hard-margin prime svm
# input:
#   x feature vectors of training data
#   y labels of training data
# output:
#   ret [w; b] where w x + b = 0 is the obtained hyperplane
# -------------------------------------------------------------------
function ret = svm(x, y)
l = rows(x);           # number of training data
n = columns(x);        # demension of feature vectors

x0 = zeros(n+1, 1);
H = eye(n+1);
H(n+1, n+1) = 0;
q = zeros(n+1, 1);
A = [];
b = [];
lb = [];
ub = [];
Ai = [x, ones(l, 1)];
Al = -inf * ones(l, 1);
Au = inf * ones(l, 1);
for i = 1:l
if y(i) == 1
Al(i) = 1;
else
Au(i) = -1;
endif
endfor

[sol, obj, info, lambda] = qp (x0, H, q, A, b, lb, ub, Al, Ai, Au);
ret = sol;

endfunction;
```

まず学習データをプロットする関数。
```#--------------------------------------------------------------------
# procedure:
#   plot training data
# input:
#   x feature vectors of training data
#   y label of training data
# output:
#   n/a
# -------------------------------------------------------------------
function plotData(x, y)
l = rows(x);  # number of training data
hold on;
for i = 1:l
if y(i) == 1
plot (x(i, 1), x(i, 2), 'ro', 'markersize', 10);
else
plot (x(i, 1), x(i, 2), 'go', 'markersize', 10);
endif
endfor
endfunction;
```
そして、直線wx + b = 0をプロットする関数。
```#--------------------------------------------------------------------
# procedure:
#   plot a line w x + b = 0
# input:
#   w weight vector
#   b bias
# output:
#   n/a
# -------------------------------------------------------------------
function plotLine(w, b)
ezplot (@(x, y) w(1) * x + w(2) * y + b);
title ("");
endfunction;
```

```octave> x = [
> 0 0
> 1 0
> 0 1
> ];
octave> y = [
> 1
> -1
> -1
> ];
octave> plotData(x, y);
octave> ret = svm(x, y);
octave> plotLine(ret(1:2), ret(3));
```
とやると、以下のようなグラフが表示されます。

ふむ。ちゃんと出来てるっぽいですね。