Search on the blog

2015年6月13日土曜日

PyplotでGradient Descentを可視化

目的関数を三次元グラフにプロット
以下では、目的関数をz = 2x2 + 5y2 - 4xyとする.
まず目的関数をグラフ表示してみる.

from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np

# define x and y
x = np.arange(-10, 10, 0.1)
y = np.arange(-10, 10, 0.1)
X, Y = np.meshgrid(x, y)

# define z: z = 2*x*x + 5*y*y - 4*x*y
vfunc = np.vectorize(lambda x, y: 2*x*x + 5*y*y - 4*x*y)
Z = vfunc(X, Y)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)

plt.show()

等高線のプロット
次に等高線をプロットしてみる.plt.contourの第4引数で等高線の本数を指定できる.

from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np

# define x and y
x = np.arange(-10, 10, 0.1)
y = np.arange(-10, 10, 0.1)
X, Y = np.meshgrid(x, y)

# define z: z = 2*x*x + 5*y*y - 4*x*y
vfunc = np.vectorize(lambda x, y: 2*x*x + 5*y*y - 4*x*y)
Z = vfunc(X, Y)

plt.figure()
CS = plt.contour(X, Y, Z, 20, colors='black')

plt.show()

Gradient Descentの可視化
最後に、Gradient Descentで目的関数を最小化する様子をグラフ化してみる.左側のサブグラフに探索点が移動する様子を、右側のサブグラフに目的関数値が減少していく様子を示した.

from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np

# define x and y
x = np.arange(-10, 10, 0.1)
y = np.arange(-10, 10, 0.1)
X, Y = np.meshgrid(x, y)

# define z: z = 2*x*x + 5*y*y - 4*x*y
func = lambda x, y: 2*x*x + 5*y*y - 4*x*y

# gradient descent
alpha = 0.03
itr = 50
A = [[2, -2], [-2, 5]]  # matrix expression of function z
xt = [5]
yt = [10]

for _ in range(itr):
    x_cur = xt[-1]
    y_cur = yt[-1]
    dx, dy = 2 * np.dot(A, [x_cur, y_cur])
    xt.append(x_cur - alpha * dx)
    yt.append(y_cur - alpha * dy)

# graph plot
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(12, 6, forward=True)

Z = np.vectorize(func)(X, Y)
ax1.contour(X, Y, Z, 20, colors='black')
ax1.plot(xt, yt, 'bo')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_title('points plot on contour')

zt = map(lambda (a, b): func(a, b), zip(xt, yt))
ax2.plot(zt, 'r')
ax2.set_xlabel('iteration')
ax2.set_ylabel('z value')
ax2.set_title('z value against iteratoin')

fig.subplots_adjust(wspace = 0.5)
plt.show()

0 件のコメント:

コメントを投稿