目的関数を三次元グラフにプロット
以下では、目的関数をz = 2x2 + 5y2 - 4xyとする.まず目的関数をグラフ表示してみる.
from mpl_toolkits.mplot3d import axes3d import matplotlib.pyplot as plt import numpy as np # define x and y x = np.arange(-10, 10, 0.1) y = np.arange(-10, 10, 0.1) X, Y = np.meshgrid(x, y) # define z: z = 2*x*x + 5*y*y - 4*x*y vfunc = np.vectorize(lambda x, y: 2*x*x + 5*y*y - 4*x*y) Z = vfunc(X, Y) fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10) plt.show()
等高線のプロット
次に等高線をプロットしてみる.plt.contourの第4引数で等高線の本数を指定できる.
from mpl_toolkits.mplot3d import axes3d import matplotlib.pyplot as plt import numpy as np # define x and y x = np.arange(-10, 10, 0.1) y = np.arange(-10, 10, 0.1) X, Y = np.meshgrid(x, y) # define z: z = 2*x*x + 5*y*y - 4*x*y vfunc = np.vectorize(lambda x, y: 2*x*x + 5*y*y - 4*x*y) Z = vfunc(X, Y) plt.figure() CS = plt.contour(X, Y, Z, 20, colors='black') plt.show()
Gradient Descentの可視化
最後に、Gradient Descentで目的関数を最小化する様子をグラフ化してみる.左側のサブグラフに探索点が移動する様子を、右側のサブグラフに目的関数値が減少していく様子を示した.
from mpl_toolkits.mplot3d import axes3d import matplotlib.pyplot as plt import numpy as np # define x and y x = np.arange(-10, 10, 0.1) y = np.arange(-10, 10, 0.1) X, Y = np.meshgrid(x, y) # define z: z = 2*x*x + 5*y*y - 4*x*y func = lambda x, y: 2*x*x + 5*y*y - 4*x*y # gradient descent alpha = 0.03 itr = 50 A = [[2, -2], [-2, 5]] # matrix expression of function z xt = [5] yt = [10] for _ in range(itr): x_cur = xt[-1] y_cur = yt[-1] dx, dy = 2 * np.dot(A, [x_cur, y_cur]) xt.append(x_cur - alpha * dx) yt.append(y_cur - alpha * dy) # graph plot fig, (ax1, ax2) = plt.subplots(1, 2) fig.set_size_inches(12, 6, forward=True) Z = np.vectorize(func)(X, Y) ax1.contour(X, Y, Z, 20, colors='black') ax1.plot(xt, yt, 'bo') ax1.set_xlabel('x') ax1.set_ylabel('y') ax1.set_title('points plot on contour') zt = map(lambda (a, b): func(a, b), zip(xt, yt)) ax2.plot(zt, 'r') ax2.set_xlabel('iteration') ax2.set_ylabel('z value') ax2.set_title('z value against iteratoin') fig.subplots_adjust(wspace = 0.5) plt.show()
0 件のコメント:
コメントを投稿